diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/HiveFormatUtils.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/HiveFormatUtils.java index 2dd2a8c6aa52..9c9b080d6500 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/HiveFormatUtils.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/HiveFormatUtils.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import io.trino.plugin.base.type.DecodedTimestamp; +import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.type.DecimalConversions; @@ -44,7 +45,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.hive.formats.HiveFormatsErrorCode.HIVE_INVALID_METADATA; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.Decimals.overflows; import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; @@ -264,18 +265,15 @@ public static List getTimestampFormatsSchemaProperty(Map char c = property.charAt(position); if (c == TIMESTAMP_FORMATS_ESCAPE) { // the next character must be an escape or separator - checkArgument( - position + 1 < property.length(), - "Invalid '%s' property value '%s': unterminated escape at end of value", - TIMESTAMP_FORMATS_KEY, - property); + if (position + 1 >= property.length()) { + throw new TrinoException(HIVE_INVALID_METADATA, + "Invalid '%s' property value '%s': unterminated escape at end of value".formatted(TIMESTAMP_FORMATS_KEY, property)); + } char nextCharacter = property.charAt(position + 1); - checkArgument( - nextCharacter == TIMESTAMP_FORMATS_SEPARATOR || nextCharacter == TIMESTAMP_FORMATS_ESCAPE, - "Invalid '%s' property value '%s': Illegal escaped character at %s", - TIMESTAMP_FORMATS_KEY, - property, - position); + if (nextCharacter != TIMESTAMP_FORMATS_SEPARATOR && nextCharacter != TIMESTAMP_FORMATS_ESCAPE) { + throw new TrinoException(HIVE_INVALID_METADATA, + "Invalid '%s' property value '%s': Illegal escaped character at %s".formatted(TIMESTAMP_FORMATS_KEY, property, position)); + } buffer.append(nextCharacter); position++; diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/HiveFormatsErrorCode.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/HiveFormatsErrorCode.java new file mode 100644 index 000000000000..98d4b3532dd9 --- /dev/null +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/HiveFormatsErrorCode.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.hive.formats; + +import io.trino.spi.ErrorCode; +import io.trino.spi.ErrorCodeSupplier; +import io.trino.spi.ErrorType; + +import static io.trino.spi.ErrorType.EXTERNAL; + +// these error codes must match the error codes in HiveErrorCode +public enum HiveFormatsErrorCode + implements ErrorCodeSupplier +{ + HIVE_INVALID_METADATA(12, EXTERNAL), + /**/; + + private final ErrorCode errorCode; + + HiveFormatsErrorCode(int code, ErrorType type) + { + errorCode = new ErrorCode(code + 0x0100_0000, name(), type); + } + + @Override + public ErrorCode toErrorCode() + { + return errorCode; + } +} diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/TestHiveFormatUtils.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/TestHiveFormatUtils.java index dc74158c1cfa..24d26b6f4f08 100644 --- a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/TestHiveFormatUtils.java +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/TestHiveFormatUtils.java @@ -13,11 +13,17 @@ */ package io.trino.hive.formats; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Test; import java.time.LocalDate; +import static io.trino.hive.formats.HiveFormatUtils.TIMESTAMP_FORMATS_KEY; +import static io.trino.hive.formats.HiveFormatUtils.getTimestampFormatsSchemaProperty; import static io.trino.hive.formats.HiveFormatUtils.parseHiveDate; +import static io.trino.hive.formats.HiveFormatsErrorCode.HIVE_INVALID_METADATA; +import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; import static org.assertj.core.api.Assertions.assertThat; public class TestHiveFormatUtils @@ -30,4 +36,23 @@ public void test() assertThat(parseHiveDate("-5877641-06-23")).isEqualTo(LocalDate.of(-5877641, 6, 23)); assertThat(parseHiveDate("1986-01-33")).isEqualTo(LocalDate.of(1986, 2, 2)); } + + @Test + public void testTimestampFormatEscaping() + { + assertTrinoExceptionThrownBy(() -> getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "\\"))) + .hasErrorCode(HIVE_INVALID_METADATA) + .hasMessageContaining("unterminated escape"); + assertTrinoExceptionThrownBy(() -> getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "\\neither backslash nor comma"))) + .hasErrorCode(HIVE_INVALID_METADATA) + .hasMessageContaining("Illegal escaped character"); + assertThat(getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "\\\\"))).isEqualTo(ImmutableList.of("\\")); + assertThat(getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "xx\\\\"))).isEqualTo(ImmutableList.of("xx\\")); + assertThat(getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "\\\\yy"))).isEqualTo(ImmutableList.of("\\yy")); + assertThat(getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "xx\\\\yy"))).isEqualTo(ImmutableList.of("xx\\yy")); + assertThat(getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "\\,"))).isEqualTo(ImmutableList.of(",")); + assertThat(getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "aa\\,"))).isEqualTo(ImmutableList.of("aa,")); + assertThat(getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "\\,bb"))).isEqualTo(ImmutableList.of(",bb")); + assertThat(getTimestampFormatsSchemaProperty(ImmutableMap.of(TIMESTAMP_FORMATS_KEY, "aa\\,bb"))).isEqualTo(ImmutableList.of("aa,bb")); + } }