diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSource.java index 18e6251ff93..aba96d2ee22 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSource.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import io.trino.plugin.hive.HivePageSourceProvider.BucketAdaptation; import io.trino.plugin.hive.HivePageSourceProvider.ColumnMapping; +import io.trino.plugin.hive.coercions.CharCoercer; import io.trino.plugin.hive.coercions.DoubleToFloatCoercer; import io.trino.plugin.hive.coercions.FloatToDoubleCoercer; import io.trino.plugin.hive.coercions.IntegerNumberToVarcharCoercer; @@ -43,6 +44,7 @@ import io.trino.spi.connector.RecordCursor; import io.trino.spi.metrics.Metrics; import io.trino.spi.type.ArrayType; +import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; @@ -313,6 +315,12 @@ private static Optional> createCoercer(TypeManager typeMa } return Optional.empty(); } + if (fromType instanceof CharType fromCharType && toType instanceof CharType toCharType) { + if (narrowerThan(toCharType, fromCharType)) { + return Optional.of(new CharCoercer(fromCharType, toCharType)); + } + return Optional.empty(); + } if (fromHiveType.equals(HIVE_BYTE) && (toHiveType.equals(HIVE_SHORT) || toHiveType.equals(HIVE_INT) || toHiveType.equals(HIVE_LONG))) { return Optional.of(new IntegerNumberUpscaleCoercer<>(fromType, toType)); } @@ -372,6 +380,13 @@ public static boolean narrowerThan(VarcharType first, VarcharType second) return first.getBoundedLength() < second.getBoundedLength(); } + public static boolean narrowerThan(CharType first, CharType second) + { + requireNonNull(first, "first is null"); + requireNonNull(second, "second is null"); + return first.getLength() < second.getLength(); + } + private static class ListCoercer implements Function { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CharCoercer.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CharCoercer.java new file mode 100644 index 00000000000..54df1a03794 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/coercions/CharCoercer.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.coercions; + +import io.airlift.slice.Slice; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.CharType; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.plugin.hive.HivePageSource.narrowerThan; +import static io.trino.spi.type.Chars.truncateToLengthAndTrimSpaces; + +public class CharCoercer + extends TypeCoercer +{ + public CharCoercer(CharType fromType, CharType toType) + { + super(fromType, toType); + checkArgument(narrowerThan(toType, fromType), "Coercer to a wider char type should not be required"); + } + + @Override + protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position) + { + Slice value = fromType.getSlice(block, position); + toType.writeSlice(blockBuilder, truncateToLengthAndTrimSpaces(value, toType)); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveCoercionPolicy.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveCoercionPolicy.java index 402cf1365dc..ac9e4d35388 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveCoercionPolicy.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveCoercionPolicy.java @@ -19,6 +19,7 @@ import io.trino.plugin.hive.type.ListTypeInfo; import io.trino.plugin.hive.type.MapTypeInfo; import io.trino.plugin.hive.type.StructTypeInfo; +import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; @@ -63,6 +64,9 @@ private boolean canCoerce(HiveType fromHiveType, HiveType toHiveType, HiveTimest toHiveType.equals(HIVE_INT) || toHiveType.equals(HIVE_LONG); } + if (fromType instanceof CharType) { + return toType instanceof CharType; + } if (toType instanceof VarcharType) { return fromHiveType.equals(HIVE_BYTE) || fromHiveType.equals(HIVE_SHORT) || fromHiveType.equals(HIVE_INT) || fromHiveType.equals(HIVE_LONG) || fromType instanceof DecimalType; } diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/BaseTestHiveCoercion.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/BaseTestHiveCoercion.java index dfe16d0cb1f..89ab0973558 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/BaseTestHiveCoercion.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/BaseTestHiveCoercion.java @@ -50,6 +50,7 @@ import static java.lang.String.format; import static java.sql.JDBCType.ARRAY; import static java.sql.JDBCType.BIGINT; +import static java.sql.JDBCType.CHAR; import static java.sql.JDBCType.DECIMAL; import static java.sql.JDBCType.DOUBLE; import static java.sql.JDBCType.FLOAT; @@ -108,6 +109,8 @@ protected void doTestHiveCoercion(HiveTableDefinition tableDefinition) "long_decimal_to_bounded_varchar", "varchar_to_bigger_varchar", "varchar_to_smaller_varchar", + "char_to_bigger_char", + "char_to_smaller_char", "id"); Function>> expected = engine -> expectedValuesForEngineProvider(engine, tableName, decimalToFloatVal, floatToDecimalVal); @@ -162,6 +165,8 @@ protected void insertTableRows(String tableName, String floatToDoubleType) " DECIMAL '12345678.123456123456', " + " 'abc', " + " 'abc', " + + " 'abc', " + + " 'abc', " + " 1), " + "(" + " CAST(ROW (NULL, 1, -100, -2323, -12345, 2) AS ROW(keep VARCHAR, ti2si TINYINT, si2int SMALLINT, int2bi INTEGER, bi2vc BIGINT, lower2uppercase BIGINT)), " + @@ -190,6 +195,8 @@ protected void insertTableRows(String tableName, String floatToDoubleType) " DECIMAL '-12345678.123456123456', " + " '\uD83D\uDCB0\uD83D\uDCB0\uD83D\uDCB0', " + " '\uD83D\uDCB0\uD83D\uDCB0\uD83D\uDCB0', " + + " '\uD83D\uDCB0\uD83D\uDCB0\uD83D\uDCB0', " + + " '\uD83D\uDCB0\uD83D\uDCB0\uD83D\uDCB0', " + " 1)", tableName, floatToDoubleType)); @@ -210,7 +217,7 @@ else if (getHiveVersionMajor() == 3 && isFormat.test("orc")) { } return ImmutableMap.>builder() - .put("row_to_row", Arrays.asList( + .put("row_to_row", ImmutableList.of( engine == Engine.TRINO ? rowBuilder() .addField("keep", "as is") @@ -232,7 +239,7 @@ else if (getHiveVersionMajor() == 3 && isFormat.test("orc")) { .addField("lower2uppercase", 2L) .build() : String.format("{\"keep\":null,\"ti2si\":1,\"si2int\":-100,\"int2bi\":-2323,\"bi2vc\":\"-12345\",%s}", hiveValueForCaseChangeField))) - .put("list_to_list", Arrays.asList( + .put("list_to_list", ImmutableList.of( engine == Engine.TRINO ? ImmutableList.of(rowBuilder() .addField("ti2int", 2) @@ -247,7 +254,7 @@ else if (getHiveVersionMajor() == 3 && isFormat.test("orc")) { .addField("bi2vc", "-12345") .build()) : "[{\"ti2int\":-2,\"si2bi\":101,\"bi2vc\":\"-12345\"}]")) - .put("map_to_map", Arrays.asList( + .put("map_to_map", ImmutableList.of( engine == Engine.TRINO ? ImmutableMap.of(2, rowBuilder() .addField("ti2bi", -3L) @@ -264,70 +271,76 @@ else if (getHiveVersionMajor() == 3 && isFormat.test("orc")) { .addField("add", null) .build()) : "{-2:{\"ti2bi\":null,\"int2bi\":-2323,\"float2double\":-1.5,\"add\":null}}")) - .put("tinyint_to_smallint", Arrays.asList( + .put("tinyint_to_smallint", ImmutableList.of( -1, 1)) - .put("tinyint_to_int", Arrays.asList( + .put("tinyint_to_int", ImmutableList.of( 2, -2)) .put("tinyint_to_bigint", Arrays.asList( -3L, null)) - .put("smallint_to_int", Arrays.asList( + .put("smallint_to_int", ImmutableList.of( 100, -100)) - .put("smallint_to_bigint", Arrays.asList( + .put("smallint_to_bigint", ImmutableList.of( -101L, 101L)) - .put("int_to_bigint", Arrays.asList( + .put("int_to_bigint", ImmutableList.of( 2323L, -2323L)) - .put("bigint_to_varchar", Arrays.asList( + .put("bigint_to_varchar", ImmutableList.of( "12345", "-12345")) - .put("float_to_double", Arrays.asList( + .put("float_to_double", ImmutableList.of( 0.5, -1.5)) - .put("double_to_float", Arrays.asList(0.5, -1.5)) - .put("shortdecimal_to_shortdecimal", Arrays.asList( + .put("double_to_float", ImmutableList.of(0.5, -1.5)) + .put("shortdecimal_to_shortdecimal", ImmutableList.of( new BigDecimal("12345678.1200"), new BigDecimal("-12345678.1200"))) - .put("shortdecimal_to_longdecimal", Arrays.asList( + .put("shortdecimal_to_longdecimal", ImmutableList.of( new BigDecimal("12345678.1200"), new BigDecimal("-12345678.1200"))) - .put("longdecimal_to_shortdecimal", Arrays.asList( + .put("longdecimal_to_shortdecimal", ImmutableList.of( new BigDecimal("12345678.12"), new BigDecimal("-12345678.12"))) - .put("longdecimal_to_longdecimal", Arrays.asList( + .put("longdecimal_to_longdecimal", ImmutableList.of( new BigDecimal("12345678.12345612345600"), new BigDecimal("-12345678.12345612345600"))) - .put("float_to_decimal", Arrays.asList(new BigDecimal(floatToDecimalVal), new BigDecimal("-" + floatToDecimalVal))) - .put("double_to_decimal", Arrays.asList(new BigDecimal("12345.12345"), new BigDecimal("-12345.12345"))) - .put("decimal_to_float", Arrays.asList( + .put("float_to_decimal", ImmutableList.of(new BigDecimal(floatToDecimalVal), new BigDecimal("-" + floatToDecimalVal))) + .put("double_to_decimal", ImmutableList.of(new BigDecimal("12345.12345"), new BigDecimal("-12345.12345"))) + .put("decimal_to_float", ImmutableList.of( Float.parseFloat(decimalToFloatVal), -Float.parseFloat(decimalToFloatVal))) - .put("decimal_to_double", Arrays.asList( + .put("decimal_to_double", ImmutableList.of( 12345.12345, -12345.12345)) - .put("short_decimal_to_varchar", Arrays.asList( + .put("short_decimal_to_varchar", ImmutableList.of( "12345.12345", "-12345.12345")) - .put("long_decimal_to_varchar", Arrays.asList( + .put("long_decimal_to_varchar", ImmutableList.of( "12345678.123456123456", "-12345678.123456123456")) - .put("short_decimal_to_bounded_varchar", Arrays.asList( + .put("short_decimal_to_bounded_varchar", ImmutableList.of( "12345.12345", "12345.12345")) - .put("long_decimal_to_bounded_varchar", Arrays.asList( + .put("long_decimal_to_bounded_varchar", ImmutableList.of( "12345678.123456123456", "-12345678.123456123456")) - .put("varchar_to_bigger_varchar", Arrays.asList( + .put("varchar_to_bigger_varchar", ImmutableList.of( "abc", "\uD83D\uDCB0\uD83D\uDCB0\uD83D\uDCB0")) - .put("varchar_to_smaller_varchar", Arrays.asList( + .put("varchar_to_smaller_varchar", ImmutableList.of( + "ab", + "\uD83D\uDCB0\uD83D\uDCB0")) + .put("char_to_bigger_char", ImmutableList.of( + "abc ", + "\uD83D\uDCB0\uD83D\uDCB0\uD83D\uDCB0 ")) + .put("char_to_smaller_char", ImmutableList.of( "ab", "\uD83D\uDCB0\uD83D\uDCB0")) - .put("id", Arrays.asList( + .put("id", ImmutableList.of( 1, 1)) .buildOrThrow(); @@ -549,6 +562,8 @@ private void assertProperAlteredTableSchema(String tableName) row("long_decimal_to_bounded_varchar", "varchar(30)"), row("varchar_to_bigger_varchar", "varchar(4)"), row("varchar_to_smaller_varchar", "varchar(2)"), + row("char_to_bigger_char", "char(4)"), + row("char_to_smaller_char", "char(2)"), row("id", "bigint")); } @@ -593,6 +608,8 @@ private void assertColumnTypes( .put("long_decimal_to_bounded_varchar", VARCHAR) .put("varchar_to_bigger_varchar", VARCHAR) .put("varchar_to_smaller_varchar", VARCHAR) + .put("char_to_bigger_char", CHAR) + .put("char_to_smaller_char", CHAR) .put("id", BIGINT) .put("nested_field", BIGINT) .buildOrThrow(); @@ -631,6 +648,8 @@ private static void alterTableColumnTypes(String tableName) onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN long_decimal_to_bounded_varchar long_decimal_to_bounded_varchar varchar(30)", tableName)); onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN varchar_to_bigger_varchar varchar_to_bigger_varchar varchar(4)", tableName)); onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN varchar_to_smaller_varchar varchar_to_smaller_varchar varchar(2)", tableName)); + onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN char_to_bigger_char char_to_bigger_char char(4)", tableName)); + onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN char_to_smaller_char char_to_smaller_char char(2)", tableName)); } protected static TableInstance mutableTableInstanceOf(TableDefinition tableDefinition) @@ -689,7 +708,7 @@ private static List column(List rows, int sqlColumnIndex) private static List> extract(List arrays) { return arrays.stream() - .map(trinoArray -> Arrays.asList((Object[]) trinoArray.getArray())) + .map(trinoArray -> ImmutableList.copyOf((Object[]) trinoArray.getArray())) .collect(toImmutableList()); } diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercionOnPartitionedTable.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercionOnPartitionedTable.java index 566434bae63..88f41358fb6 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercionOnPartitionedTable.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercionOnPartitionedTable.java @@ -96,7 +96,9 @@ private static HiveTableDefinition.HiveTableDefinitionBuilder tableDefinitionBui " short_decimal_to_bounded_varchar DECIMAL(10,5)," + " long_decimal_to_bounded_varchar DECIMAL(20,12)," + " varchar_to_bigger_varchar VARCHAR(3)," + - " varchar_to_smaller_varchar VARCHAR(3)" + + " varchar_to_smaller_varchar VARCHAR(3)," + + " char_to_bigger_char CHAR(3)," + + " char_to_smaller_char CHAR(3)" + ") " + "PARTITIONED BY (id BIGINT) " + rowFormat.map(s -> format("ROW FORMAT %s ", s)).orElse("") + diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercionOnUnpartitionedTable.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercionOnUnpartitionedTable.java index bd2d9bdc227..442faa342d4 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercionOnUnpartitionedTable.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestHiveCoercionOnUnpartitionedTable.java @@ -70,6 +70,8 @@ short_decimal_to_bounded_varchar DECIMAL(10,5), long_decimal_to_bounded_varchar DECIMAL(20,12), varchar_to_bigger_varchar VARCHAR(3), varchar_to_smaller_varchar VARCHAR(3), + char_to_bigger_char CHAR(3), + char_to_smaller_char CHAR(3), id BIGINT) STORED AS\s""" + fileFormat); }