diff --git a/lib/trino-hive-formats/pom.xml b/lib/trino-hive-formats/pom.xml index 64e917a421cf..e658c72d4cd5 100644 --- a/lib/trino-hive-formats/pom.xml +++ b/lib/trino-hive-formats/pom.xml @@ -39,6 +39,11 @@ true + + io.airlift + log + + io.airlift slice @@ -81,6 +86,11 @@ joda-time + + org.apache.avro + avro + + org.gaul modernizer-maven-annotations @@ -113,6 +123,13 @@ test + + io.trino + trino-main + test-jar + test + + io.trino trino-testing-services diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/UnionToRowCoercionUtils.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/UnionToRowCoercionUtils.java new file mode 100644 index 000000000000..69ef8cc4edd0 --- /dev/null +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/UnionToRowCoercionUtils.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package io.trino.hive.formats;

import com.google.common.collect.ImmutableList;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;

import java.util.List;

import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.TypeSignatureParameter.namedField;

/**
 * Helpers for representing an Avro/Hive union as a Trino row type of the shape
 * {@code row(tag tinyint, field0 T0, field1 T1, ...)} where {@code tag} selects
 * which {@code fieldN} member is populated.
 */
public final class UnionToRowCoercionUtils
{
    public static final String UNION_FIELD_TAG_NAME = "tag";
    public static final String UNION_FIELD_FIELD_PREFIX = "field";
    public static final Type UNION_FIELD_TAG_TYPE = TINYINT;

    private UnionToRowCoercionUtils() {}

    /**
     * Builds the row type {@code row(tag tinyint, field0 T0, ...)} for a union over {@code types}.
     *
     * @param types the member types of the union, in union declaration order
     */
    public static RowType rowTypeForUnionOfTypes(List<Type> types)
    {
        ImmutableList.Builder<RowType.Field> fields = ImmutableList.<RowType.Field>builder()
                .add(RowType.field(UNION_FIELD_TAG_NAME, UNION_FIELD_TAG_TYPE));
        for (int i = 0; i < types.size(); i++) {
            fields.add(RowType.field(UNION_FIELD_FIELD_PREFIX + i, types.get(i)));
        }
        return RowType.from(fields.build());
    }

    /**
     * Type-signature analogue of {@link #rowTypeForUnionOfTypes(List)}.
     *
     * @param typeSignatures the member type signatures of the union, in union declaration order
     */
    public static TypeSignature rowTypeSignatureForUnionOfTypes(List<TypeSignature> typeSignatures)
    {
        ImmutableList.Builder<TypeSignatureParameter> fields = ImmutableList.builder();
        fields.add(namedField(UNION_FIELD_TAG_NAME, UNION_FIELD_TAG_TYPE.getTypeSignature()));
        for (int i = 0; i < typeSignatures.size(); i++) {
            fields.add(namedField(UNION_FIELD_FIELD_PREFIX + i, typeSignatures.get(i)));
        }
        return TypeSignature.rowType(fields.build());
    }
}
package io.trino.hive.formats.avro;

import io.trino.filesystem.TrinoInputFile;
import io.trino.hive.formats.TrinoDataInputStream;
import io.trino.spi.Page;
import org.apache.avro.AvroRuntimeException;
import org.apache.avro.Schema;
import org.apache.avro.file.DataFileReader;
import org.apache.avro.file.SeekableInput;

import java.io.Closeable;
import java.io.IOException;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.function.Function;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;

/**
 * Reads an Avro container file into Trino {@link Page}s.
 * <p>
 * Supports reading a sub-range of the file: the reader seeks to the first Avro
 * sync point at or after {@code offset} and stops once it passes
 * {@code offset + length}, mirroring Avro's split semantics.
 */
public class AvroFileReader
        implements Closeable
{
    private final TrinoDataInputStream input;
    private final AvroPageDataReader dataReader;
    private final DataFileReader<Optional<Page>> fileReader;
    // exclusive end of the byte range to read; empty means read to end of file
    private final OptionalLong end;
    // page buffered by loadNextPageIfNecessary; null when no page is pending
    private Page nextPage;

    public AvroFileReader(
            TrinoInputFile inputFile,
            Schema schema,
            AvroTypeManager avroTypeManager)
            throws IOException, AvroTypeException
    {
        this(inputFile, schema, avroTypeManager, 0, OptionalLong.empty());
    }

    /**
     * @param offset byte offset to begin reading from; the reader advances to the next sync point
     * @param length number of bytes to read; empty reads to end of file
     * @throws AvroTypeException when the schema cannot be mapped to Trino types
     */
    public AvroFileReader(
            TrinoInputFile inputFile,
            Schema schema,
            AvroTypeManager avroTypeManager,
            long offset,
            OptionalLong length)
            throws IOException, AvroTypeException
    {
        requireNonNull(inputFile, "inputFile is null");
        requireNonNull(schema, "schema is null");
        requireNonNull(avroTypeManager, "avroTypeManager is null");
        long fileSize = inputFile.length();

        verify(offset >= 0, "offset is negative");
        // use the cached fileSize rather than querying the input file a second time
        verify(offset < fileSize, "offset is greater than data size");
        length.ifPresent(lengthLong -> verify(lengthLong >= 1, "length must be at least 1"));
        end = length.isPresent() ? OptionalLong.of(offset + length.getAsLong()) : OptionalLong.empty();
        end.ifPresent(endLong -> verify(endLong <= fileSize, "offset plus length is greater than data size"));
        input = new TrinoDataInputStream(inputFile.newStream());
        dataReader = new AvroPageDataReader(schema, avroTypeManager);
        try {
            fileReader = new DataFileReader<>(new TrinoDataInputStreamAsAvroSeekableInput(input, fileSize), dataReader);
            fileReader.sync(offset);
        }
        catch (AvroPageDataReader.UncheckedAvroTypeException runtimeWrapper) {
            // Avro Datum Reader interface can't throw checked exceptions when initialized by the file reader,
            // so the exception is wrapped in a runtime exception that must be unwrapped
            throw runtimeWrapper.getAvroTypeException();
        }
        avroTypeManager.configure(fileReader.getMetaKeys().stream().collect(toImmutableMap(Function.identity(), fileReader::getMeta)));
    }

    public long getCompletedBytes()
    {
        return input.getReadBytes();
    }

    public long getReadTimeNanos()
    {
        return input.getReadTimeNanos();
    }

    public boolean hasNext()
            throws IOException
    {
        loadNextPageIfNecessary();
        return nextPage != null;
    }

    public Page next()
            throws IOException
    {
        if (!hasNext()) {
            throw new IOException("No more pages available from Avro file");
        }
        Page result = nextPage;
        nextPage = null;
        return result;
    }

    private void loadNextPageIfNecessary()
            throws IOException
    {
        // keep pulling records until the page builder produces a full page,
        // the split boundary is passed, or the file is exhausted
        while (nextPage == null && (end.isEmpty() || !fileReader.pastSync(end.getAsLong())) && fileReader.hasNext()) {
            try {
                nextPage = fileReader.next().orElse(null);
            }
            catch (AvroRuntimeException e) {
                throw new IOException(e);
            }
        }
        if (nextPage == null) {
            // flush any partially filled page at end of input
            nextPage = dataReader.flush().orElse(null);
        }
    }

    @Override
    public void close()
            throws IOException
    {
        fileReader.close();
    }

    /**
     * Adapts {@link TrinoDataInputStream} to Avro's {@link SeekableInput}, which requires
     * random access plus a known total length.
     */
    private record TrinoDataInputStreamAsAvroSeekableInput(TrinoDataInputStream inputStream, long fileSize)
            implements SeekableInput
    {
        TrinoDataInputStreamAsAvroSeekableInput
        {
            requireNonNull(inputStream, "inputStream is null");
        }

        @Override
        public void seek(long p)
                throws IOException
        {
            inputStream.seek(p);
        }

        @Override
        public long tell()
                throws IOException
        {
            return inputStream.getPos();
        }

        @Override
        public long length()
        {
            return fileSize;
        }

        @Override
        public int read(byte[] b, int off, int len)
                throws IOException
        {
            return inputStream.read(b, off, len);
        }

        @Override
        public void close()
                throws IOException
        {
            inputStream.close();
        }
    }
}
+ */ +package io.trino.hive.formats.avro; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.spi.Page; +import io.trino.spi.PageBuilder; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.SingleRowBlockWriter; +import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; +import org.apache.avro.Resolver; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData; +import org.apache.avro.io.BinaryDecoder; +import org.apache.avro.io.DatumReader; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.Encoder; +import org.apache.avro.io.EncoderFactory; +import org.apache.avro.io.FastReaderBuilder; +import org.apache.avro.io.parsing.ResolvingGrammarGenerator; +import org.apache.avro.util.internal.Accessor; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.function.IntFunction; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static io.trino.hive.formats.UnionToRowCoercionUtils.UNION_FIELD_TAG_TYPE; +import static io.trino.hive.formats.avro.AvroTypeUtils.isSimpleNullableUnion; +import static io.trino.hive.formats.avro.AvroTypeUtils.typeFromAvro; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static java.lang.Float.floatToRawIntBits; +import static java.util.Objects.requireNonNull; + +public class AvroPageDataReader + implements DatumReader> +{ + // same 
limit as org.apache.avro.io.BinaryDecoder + private static final long MAX_ARRAY_SIZE = (long) Integer.MAX_VALUE - 8L; + + private final Schema readerSchema; + private Schema writerSchema; + private final PageBuilder pageBuilder; + private RowBlockBuildingDecoder rowBlockBuildingDecoder; + private final AvroTypeManager typeManager; + + public AvroPageDataReader(Schema readerSchema, AvroTypeManager typeManager) + throws AvroTypeException + { + this.readerSchema = requireNonNull(readerSchema, "readerSchema is null"); + writerSchema = this.readerSchema; + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + try { + Type readerSchemaType = typeFromAvro(this.readerSchema, typeManager); + verify(readerSchemaType instanceof RowType, "Root Avro type must be a row"); + pageBuilder = new PageBuilder(readerSchemaType.getTypeParameters()); + initialize(); + } + catch (org.apache.avro.AvroTypeException e) { + throw new AvroTypeException(e); + } + } + + private void initialize() + throws AvroTypeException + { + verify(readerSchema.getType() == Schema.Type.RECORD, "Avro schema for page reader must be record"); + verify(writerSchema.getType() == Schema.Type.RECORD, "File Avro schema for page reader must be record"); + rowBlockBuildingDecoder = new RowBlockBuildingDecoder(writerSchema, readerSchema, typeManager); + } + + @Override + public void setSchema(Schema schema) + { + requireNonNull(schema, "schema is null"); + if (schema != writerSchema) { + writerSchema = schema; + try { + initialize(); + } + catch (org.apache.avro.AvroTypeException e) { + throw new UncheckedAvroTypeException(new AvroTypeException(e)); + } + catch (AvroTypeException e) { + throw new UncheckedAvroTypeException(e); + } + } + } + + @Override + public Optional read(Optional ignoredReuse, Decoder decoder) + throws IOException + { + Optional page = Optional.empty(); + rowBlockBuildingDecoder.decodeIntoPageBuilder(decoder, pageBuilder); + if (pageBuilder.isFull()) { + page = 
Optional.of(pageBuilder.build()); + pageBuilder.reset(); + } + return page; + } + + public Optional flush() + { + if (!pageBuilder.isEmpty()) { + Optional lastPage = Optional.of(pageBuilder.build()); + pageBuilder.reset(); + return lastPage; + } + return Optional.empty(); + } + + private abstract static class BlockBuildingDecoder + { + protected abstract void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException; + } + + private static BlockBuildingDecoder createBlockBuildingDecoderForAction(Resolver.Action action, AvroTypeManager typeManager) + throws AvroTypeException + { + Optional> consumer = typeManager.overrideBuildingFunctionForSchema(action.reader); + if (consumer.isPresent()) { + return new UserDefinedBlockBuildingDecoder(action.reader, action.writer, consumer.get()); + } + return switch (action.type) { + case DO_NOTHING -> switch (action.reader.getType()) { + case NULL -> NullBlockBuildingDecoder.INSTANCE; + case BOOLEAN -> BooleanBlockBuildingDecoder.INSTANCE; + case INT -> IntBlockBuildingDecoder.INSTANCE; + case LONG -> new LongBlockBuildingDecoder(); + case FLOAT -> new FloatBlockBuildingDecoder(); + case DOUBLE -> new DoubleBlockBuildingDecoder(); + case STRING -> StringBlockBuildingDecoder.INSTANCE; + case BYTES -> BytesBlockBuildingDecoder.INSTANCE; + case FIXED -> new FixedBlockBuildingDecoder(action.reader.getFixedSize()); + // these reader types covered by special action types + case ENUM, ARRAY, MAP, RECORD, UNION -> throw new IllegalStateException("Do Nothing action type not compatible with reader schema type " + action.reader.getType()); + }; + case PROMOTE -> switch (action.reader.getType()) { + // only certain types valid to promote into as determined by org.apache.avro.Resolver.Promote.isValid + case LONG -> new LongBlockBuildingDecoder(getLongPromotionFunction(action.writer)); + case FLOAT -> new FloatBlockBuildingDecoder(getFloatPromotionFunction(action.writer)); + case DOUBLE -> new 
DoubleBlockBuildingDecoder(getDoublePromotionFunction(action.writer)); + case STRING -> { + if (action.writer.getType() == Schema.Type.BYTES) { + yield StringBlockBuildingDecoder.INSTANCE; + } + throw new AvroTypeException("Unable to promote to String from type " + action.writer.getType()); + } + case BYTES -> { + if (action.writer.getType() == Schema.Type.STRING) { + yield BytesBlockBuildingDecoder.INSTANCE; + } + throw new AvroTypeException("Unable to promote to Bytes from type " + action.writer.getType()); + } + case NULL, BOOLEAN, INT, FIXED, ENUM, ARRAY, MAP, RECORD, UNION -> + throw new AvroTypeException("Promotion action not allowed for reader schema type " + action.reader.getType()); + }; + case CONTAINER -> switch (action.reader.getType()) { + case ARRAY -> new ArrayBlockBuildingDecoder((Resolver.Container) action, typeManager); + case MAP -> new MapBlockBuildingDecoder((Resolver.Container) action, typeManager); + default -> throw new AvroTypeException("Not possible to have container action type with non container reader schema " + action.reader.getType()); + }; + case RECORD -> new RowBlockBuildingDecoder(action, typeManager); + case ENUM -> new EnumBlockBuildingDecoder((Resolver.EnumAdjust) action); + case WRITER_UNION -> { + if (isSimpleNullableUnion(action.reader)) { + yield new WriterUnionBlockBuildingDecoder((Resolver.WriterUnion) action, typeManager); + } + else { + yield new WriterUnionCoercedIntoRowBlockBuildingDecoder((Resolver.WriterUnion) action, typeManager); + } + } + case READER_UNION -> { + if (isSimpleNullableUnion(action.reader)) { + yield createBlockBuildingDecoderForAction(((Resolver.ReaderUnion) action).actualAction, typeManager); + } + else { + yield new ReaderUnionCoercedIntoRowBlockBuildingDecoder((Resolver.ReaderUnion) action, typeManager); + } + } + case ERROR -> throw new AvroTypeException("Resolution action returned with error " + action); + case SKIP -> throw new IllegalStateException("Skips filtered by row step"); + }; + } + + 
// Different plugins may have different Avro Schema to Type mappings + // that are currently transforming GenericDatumReader returned objects into their target type during the record reading process + // This block building decoder allows plugin writers to port that code directly and use within this reader + // This mechanism is used to enhance Avro longs into timestamp types according to schema metadata + private static class UserDefinedBlockBuildingDecoder + extends BlockBuildingDecoder + { + private final BiConsumer userBuilderFunction; + private final DatumReader datumReader; + + public UserDefinedBlockBuildingDecoder(Schema readerSchema, Schema writerSchema, BiConsumer userBuilderFunction) + throws AvroTypeException + { + requireNonNull(readerSchema, "readerSchema is null"); + requireNonNull(writerSchema, "writerSchema is null"); + try { + FastReaderBuilder fastReaderBuilder = new FastReaderBuilder(new GenericData()); + datumReader = fastReaderBuilder.createDatumReader(writerSchema, readerSchema); + } + catch (IOException ioException) { + // IOException only thrown when default encoded in schema is unable to be re-serialized into bytes with proper typing + // translate into type exception + throw new AvroTypeException("Unable to decode default value in schema " + readerSchema, ioException); + } + this.userBuilderFunction = requireNonNull(userBuilderFunction, "userBuilderFunction is null"); + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + userBuilderFunction.accept(builder, datumReader.read(null, decoder)); + } + } + + private static class NullBlockBuildingDecoder + extends BlockBuildingDecoder + { + private static final NullBlockBuildingDecoder INSTANCE = new NullBlockBuildingDecoder(); + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + decoder.readNull(); + builder.appendNull(); + } + } + + private static class 
BooleanBlockBuildingDecoder + extends BlockBuildingDecoder + { + private static final BooleanBlockBuildingDecoder INSTANCE = new BooleanBlockBuildingDecoder(); + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + BOOLEAN.writeBoolean(builder, decoder.readBoolean()); + } + } + + private static class IntBlockBuildingDecoder + extends BlockBuildingDecoder + { + private static final IntBlockBuildingDecoder INSTANCE = new IntBlockBuildingDecoder(); + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + INTEGER.writeLong(builder, decoder.readInt()); + } + } + + private static class LongBlockBuildingDecoder + extends BlockBuildingDecoder + { + private static final LongIoFunction DEFAULT_EXTRACT_LONG = Decoder::readLong; + private final LongIoFunction extractLong; + + public LongBlockBuildingDecoder() + { + this(DEFAULT_EXTRACT_LONG); + } + + public LongBlockBuildingDecoder(LongIoFunction extractLong) + { + this.extractLong = requireNonNull(extractLong, "extractLong is null"); + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + BIGINT.writeLong(builder, extractLong.apply(decoder)); + } + } + + private static class FloatBlockBuildingDecoder + extends BlockBuildingDecoder + { + private static final FloatIoFunction DEFAULT_EXTRACT_FLOAT = Decoder::readFloat; + private final FloatIoFunction extractFloat; + + public FloatBlockBuildingDecoder() + { + this(DEFAULT_EXTRACT_FLOAT); + } + + public FloatBlockBuildingDecoder(FloatIoFunction extractFloat) + { + this.extractFloat = requireNonNull(extractFloat, "extractFloat is null"); + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + REAL.writeLong(builder, floatToRawIntBits(extractFloat.apply(decoder))); + } + } + + private static class DoubleBlockBuildingDecoder + extends 
BlockBuildingDecoder + { + private static final DoubleIoFunction DEFAULT_EXTRACT_DOUBLE = Decoder::readDouble; + private final DoubleIoFunction extractDouble; + + public DoubleBlockBuildingDecoder() + { + this(DEFAULT_EXTRACT_DOUBLE); + } + + public DoubleBlockBuildingDecoder(DoubleIoFunction extractDouble) + { + this.extractDouble = requireNonNull(extractDouble, "extractDouble is null"); + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + DOUBLE.writeDouble(builder, extractDouble.apply(decoder)); + } + } + + private static class StringBlockBuildingDecoder + extends BlockBuildingDecoder + { + private static final StringBlockBuildingDecoder INSTANCE = new StringBlockBuildingDecoder(); + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + // it is only possible to read String type when underlying write type is String or Bytes + // both have the same encoding, so coercion is a no-op + long size = decoder.readLong(); + if (size > MAX_ARRAY_SIZE) { + throw new IOException("Unable to read avro String with size greater than %s. Found String size: %s".formatted(MAX_ARRAY_SIZE, size)); + } + byte[] bytes = new byte[(int) size]; + decoder.readFixed(bytes); + VARCHAR.writeSlice(builder, Slices.wrappedBuffer(bytes)); + } + } + + private static class BytesBlockBuildingDecoder + extends BlockBuildingDecoder + { + private static final BytesBlockBuildingDecoder INSTANCE = new BytesBlockBuildingDecoder(); + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + // it is only possible to read Bytes type when underlying write type is String or Bytes + // both have the same encoding, so coercion is a no-op + long size = decoder.readLong(); + if (size > MAX_ARRAY_SIZE) { + throw new IOException("Unable to read avro Bytes with size greater than %s. 
Found Bytes size: %s".formatted(MAX_ARRAY_SIZE, size)); + } + byte[] bytes = new byte[(int) size]; + decoder.readFixed(bytes); + VARBINARY.writeSlice(builder, Slices.wrappedBuffer(bytes)); + } + } + + private static class FixedBlockBuildingDecoder + extends BlockBuildingDecoder + { + private final int expectedSize; + + public FixedBlockBuildingDecoder(int expectedSize) + { + verify(expectedSize >= 0, "expected size must be greater than or equal to 0"); + this.expectedSize = expectedSize; + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + byte[] slice = new byte[expectedSize]; + decoder.readFixed(slice); + VARBINARY.writeSlice(builder, Slices.wrappedBuffer(slice)); + } + } + + private static class EnumBlockBuildingDecoder + extends BlockBuildingDecoder + { + private Slice[] symbols; + + public EnumBlockBuildingDecoder(Resolver.EnumAdjust action) + throws AvroTypeException + { + List symbolsList = requireNonNull(action, "action is null").reader.getEnumSymbols(); + symbols = symbolsList.stream().map(Slices::utf8Slice).toArray(Slice[]::new); + if (!action.noAdjustmentsNeeded) { + Slice[] adjustedSymbols = new Slice[(action.writer.getEnumSymbols().size())]; + for (int i = 0; i < action.adjustments.length; i++) { + if (action.adjustments[i] < 0) { + throw new AvroTypeException("No reader Enum value for writer Enum value " + action.writer.getEnumSymbols().get(i)); + } + adjustedSymbols[i] = symbols[action.adjustments[i]]; + } + symbols = adjustedSymbols; + } + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + VARCHAR.writeSlice(builder, symbols[decoder.readEnum()]); + } + } + + private static class ArrayBlockBuildingDecoder + extends BlockBuildingDecoder + { + private final BlockBuildingDecoder elementBlockBuildingDecoder; + + public ArrayBlockBuildingDecoder(Resolver.Container containerAction, AvroTypeManager typeManager) + throws 
AvroTypeException + { + requireNonNull(containerAction, "containerAction is null"); + verify(containerAction.reader.getType() == Schema.Type.ARRAY, "Reader schema must be a array"); + elementBlockBuildingDecoder = createBlockBuildingDecoderForAction(containerAction.elementAction, typeManager); + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + BlockBuilder elementBuilder = builder.beginBlockEntry(); + long elementsInBlock = decoder.readArrayStart(); + if (elementsInBlock > 0) { + do { + for (int i = 0; i < elementsInBlock; i++) { + elementBlockBuildingDecoder.decodeIntoBlock(decoder, elementBuilder); + } + } + while ((elementsInBlock = decoder.arrayNext()) > 0); + } + builder.closeEntry(); + } + } + + private static class MapBlockBuildingDecoder + extends BlockBuildingDecoder + { + private final BlockBuildingDecoder keyBlockBuildingDecoder = new StringBlockBuildingDecoder(); + private final BlockBuildingDecoder valueBlockBuildingDecoder; + + public MapBlockBuildingDecoder(Resolver.Container containerAction, AvroTypeManager typeManager) + throws AvroTypeException + { + requireNonNull(containerAction, "containerAction is null"); + verify(containerAction.reader.getType() == Schema.Type.MAP, "Reader schema must be a map"); + valueBlockBuildingDecoder = createBlockBuildingDecoderForAction(containerAction.elementAction, typeManager); + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + BlockBuilder entryBuilder = builder.beginBlockEntry(); + long entriesInBlock = decoder.readMapStart(); + // TODO need to filter out all but last value for key? 
+ if (entriesInBlock > 0) { + do { + for (int i = 0; i < entriesInBlock; i++) { + keyBlockBuildingDecoder.decodeIntoBlock(decoder, entryBuilder); + valueBlockBuildingDecoder.decodeIntoBlock(decoder, entryBuilder); + } + } + while ((entriesInBlock = decoder.mapNext()) > 0); + } + builder.closeEntry(); + } + } + + private static class RowBlockBuildingDecoder + extends BlockBuildingDecoder + { + private final RowBuildingAction[] buildSteps; + + private RowBlockBuildingDecoder(Schema writeSchema, Schema readSchema, AvroTypeManager typeManager) + throws AvroTypeException + { + this(Resolver.resolve(writeSchema, readSchema, new GenericData()), typeManager); + } + + private RowBlockBuildingDecoder(Resolver.Action action, AvroTypeManager typeManager) + throws AvroTypeException + + { + if (action instanceof Resolver.ErrorAction errorAction) { + throw new AvroTypeException("Error in resolution of types for row building: " + errorAction.error); + } + if (!(action instanceof Resolver.RecordAdjust recordAdjust)) { + throw new AvroTypeException("Write and Read Schemas must be records when building a row block building decoder. 
Illegal action: " + action); + } + buildSteps = new RowBuildingAction[recordAdjust.fieldActions.length + recordAdjust.readerOrder.length + - recordAdjust.firstDefault]; + int i = 0; + int readerFieldCount = 0; + for (; i < recordAdjust.fieldActions.length; i++) { + Resolver.Action fieldAction = recordAdjust.fieldActions[i]; + if (fieldAction instanceof Resolver.Skip skip) { + buildSteps[i] = new SkipSchemaBuildingAction(skip.writer); + } + else { + Schema.Field readField = recordAdjust.readerOrder[readerFieldCount++]; + buildSteps[i] = new BuildIntoBlockAction(createBlockBuildingDecoderForAction(fieldAction, typeManager), readField.pos()); + } + } + + // add defaulting if required + for (; i < buildSteps.length; i++) { + // create constant block + Schema.Field readField = recordAdjust.readerOrder[readerFieldCount++]; + // TODO see if it can be done with RLE block + buildSteps[i] = new ConstantBlockAction(getDefaultBlockBuilder(readField, typeManager), readField.pos()); + } + + verify(Arrays.stream(buildSteps) + .mapToInt(RowBuildingAction::getOutputChannel) + .filter(a -> a >= 0) + .distinct() + .sum() == (recordAdjust.reader.getFields().size() * (recordAdjust.reader.getFields().size() - 1) / 2), + "Every channel in output block builder must be accounted for"); + verify(Arrays.stream(buildSteps) + .mapToInt(RowBuildingAction::getOutputChannel) + .filter(a -> a >= 0) + .distinct().count() == (long) recordAdjust.reader.getFields().size(), "Every channel in output block builder must be accounted for"); + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + SingleRowBlockWriter currentBuilder = (SingleRowBlockWriter) builder.beginBlockEntry(); + decodeIntoBlockProvided(decoder, currentBuilder::getFieldBlockBuilder); + builder.closeEntry(); + } + + protected void decodeIntoPageBuilder(Decoder decoder, PageBuilder builder) + throws IOException + { + builder.declarePosition(); + decodeIntoBlockProvided(decoder, 
builder::getBlockBuilder); + } + + protected void decodeIntoBlockProvided(Decoder decoder, IntFunction fieldBlockBuilder) + throws IOException + { + for (RowBuildingAction buildStep : buildSteps) { + // TODO replace with switch sealed class syntax when stable + if (buildStep instanceof SkipSchemaBuildingAction skipSchemaBuildingAction) { + skipSchemaBuildingAction.skip(decoder); + } + else if (buildStep instanceof BuildIntoBlockAction buildIntoBlockAction) { + buildIntoBlockAction.decode(decoder, fieldBlockBuilder); + } + else if (buildStep instanceof ConstantBlockAction constantBlockAction) { + constantBlockAction.addConstant(fieldBlockBuilder); + } + else { + throw new IllegalStateException("Unhandled buildingAction"); + } + } + } + + sealed interface RowBuildingAction + permits BuildIntoBlockAction, ConstantBlockAction, SkipSchemaBuildingAction + { + int getOutputChannel(); + } + + private static final class BuildIntoBlockAction + implements RowBuildingAction + { + private final BlockBuildingDecoder delegate; + private final int outputChannel; + + public BuildIntoBlockAction(BlockBuildingDecoder delegate, int outputChannel) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + checkArgument(outputChannel >= 0, "outputChannel must be positive"); + this.outputChannel = outputChannel; + } + + public void decode(Decoder decoder, IntFunction channelSelector) + throws IOException + { + delegate.decodeIntoBlock(decoder, channelSelector.apply(outputChannel)); + } + + @Override + public int getOutputChannel() + { + return outputChannel; + } + } + + protected static final class ConstantBlockAction + implements RowBuildingAction + { + private final IoConsumer addConstantFunction; + private final int outputChannel; + + public ConstantBlockAction(IoConsumer addConstantFunction, int outputChannel) + { + this.addConstantFunction = requireNonNull(addConstantFunction, "addConstantFunction is null"); + checkArgument(outputChannel >= 0, "outputChannel must be 
    /**
     * Row-building action for a writer-schema field that has no corresponding reader
     * column: the encoded value is consumed from the stream and discarded without
     * materializing anything.
     */
    private static final class SkipSchemaBuildingAction
            implements RowBuildingAction
    {
        // Pre-compiled skip routine mirroring the structure of the writer schema.
        private final SkipAction skipAction;

        SkipSchemaBuildingAction(Schema schema)
        {
            skipAction = createSkipActionForSchema(requireNonNull(schema, "schema is null"));
        }

        // Consume and discard exactly one datum of the configured schema.
        public void skip(Decoder decoder)
                throws IOException
        {
            skipAction.skip(decoder);
        }

        @Override
        public int getOutputChannel()
        {
            // Skipped fields produce no output column.
            return -1;
        }

        @FunctionalInterface
        private interface SkipAction
        {
            void skip(Decoder decoder)
                    throws IOException;
        }

        // Builds a tree of SkipActions so per-datum skipping does no schema inspection.
        private static SkipAction createSkipActionForSchema(Schema schema)
        {
            return switch (schema.getType()) {
                case NULL -> Decoder::readNull;
                case BOOLEAN -> Decoder::readBoolean;
                case INT -> Decoder::readInt;
                case LONG -> Decoder::readLong;
                case FLOAT -> Decoder::readFloat;
                case DOUBLE -> Decoder::readDouble;
                case STRING -> Decoder::skipString;
                case BYTES -> Decoder::skipBytes;
                case ENUM -> Decoder::readEnum;
                case FIXED -> {
                    // Fixed values have a schema-declared byte width; capture it once here.
                    int size = schema.getFixedSize();
                    yield decoder -> decoder.skipFixed(size);
                }
                case ARRAY -> new ArraySkipAction(schema.getElementType());
                case MAP -> new MapSkipAction(schema.getValueType());
                case RECORD -> new RecordSkipAction(schema.getFields());
                case UNION -> new UnionSkipAction(schema.getTypes());
            };
        }

        private static class ArraySkipAction
                implements SkipAction
        {
            private final SkipAction elementSkipAction;

            public ArraySkipAction(Schema elementSchema)
            {
                elementSkipAction = createSkipActionForSchema(requireNonNull(elementSchema, "elementSchema is null"));
            }

            @Override
            public void skip(Decoder decoder)
                    throws IOException
            {
                // Avro arrays are encoded in blocks; skipArray returns the element count
                // of the next block, and 0 when the array ends.
                for (long i = decoder.skipArray(); i != 0; i = decoder.skipArray()) {
                    for (long j = 0; j < i; j++) {
                        elementSkipAction.skip(decoder);
                    }
                }
            }
        }

        private static class MapSkipAction
                implements SkipAction
        {
            private final SkipAction valueSkipAction;

            public MapSkipAction(Schema valueSchema)
            {
                valueSkipAction = createSkipActionForSchema(requireNonNull(valueSchema, "valueSchema is null"));
            }

            @Override
            public void skip(Decoder decoder)
                    throws IOException
            {
                // Maps are block-encoded like arrays; each entry is a string key
                // followed by a value of the map's value schema.
                for (long i = decoder.skipMap(); i != 0; i = decoder.skipMap()) {
                    for (long j = 0; j < i; j++) {
                        decoder.skipString(); // key
                        valueSkipAction.skip(decoder); // value
                    }
                }
            }
        }

        private static class RecordSkipAction
                implements SkipAction
        {
            private final SkipAction[] fieldSkips;

            public RecordSkipAction(List<Schema.Field> fields)
            {
                fieldSkips = new SkipAction[requireNonNull(fields, "fields is null").size()];
                for (int i = 0; i < fields.size(); i++) {
                    fieldSkips[i] = createSkipActionForSchema(fields.get(i).schema());
                }
            }

            @Override
            public void skip(Decoder decoder)
                    throws IOException
            {
                // Record fields are encoded back to back in declaration order.
                for (SkipAction fieldSkipAction : fieldSkips) {
                    fieldSkipAction.skip(decoder);
                }
            }
        }

        private static class UnionSkipAction
                implements SkipAction
        {
            private final SkipAction[] skipActions;

            private UnionSkipAction(List<Schema> types)
            {
                skipActions = new SkipAction[requireNonNull(types, "types is null").size()];
                for (int i = 0; i < types.size(); i++) {
                    skipActions[i] = createSkipActionForSchema(types.get(i));
                }
            }

            @Override
            public void skip(Decoder decoder)
                    throws IOException
            {
                // A union value is a branch index followed by the branch's datum.
                skipActions[decoder.readIndex()].skip(decoder);
            }
        }
    }
    /**
     * Decodes a writer-side union by reading the branch index from the stream and
     * delegating to the decoder resolved for that branch.
     */
    private static class WriterUnionBlockBuildingDecoder
            extends BlockBuildingDecoder
    {
        // One decoder per writer-union branch, indexed by the branch number on the wire.
        protected final BlockBuildingDecoder[] blockBuildingDecoders;

        public WriterUnionBlockBuildingDecoder(Resolver.WriterUnion writerUnion, AvroTypeManager typeManager)
                throws AvroTypeException
        {
            requireNonNull(writerUnion, "writerUnion is null");
            blockBuildingDecoders = new BlockBuildingDecoder[writerUnion.actions.length];
            for (int i = 0; i < writerUnion.actions.length; i++) {
                blockBuildingDecoders[i] = createBlockBuildingDecoderForAction(writerUnion.actions[i], typeManager);
            }
        }

        @Override
        protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder)
                throws IOException
        {
            // The writer encodes the branch index immediately before the branch value.
            decodeIntoBlock(decoder.readIndex(), decoder, builder);
        }

        protected void decodeIntoBlock(int blockBuilderIndex, Decoder decoder, BlockBuilder builder)
                throws IOException
        {
            blockBuildingDecoders[blockBuilderIndex].decodeIntoBlock(decoder, builder);
        }
    }

    /**
     * Decodes a writer-side union into Trino's row coercion of a union:
     * a row of (tag, field0, field1, ...) where exactly one field is non-null.
     */
    private static class WriterUnionCoercedIntoRowBlockBuildingDecoder
            extends WriterUnionBlockBuildingDecoder
    {
        // True when reader and writer unions resolve branch-for-branch (same ordering)
        private final boolean readUnionEquiv;
        // Maps union branch index -> row field channel; -1 for the NULL branch
        private final int[] indexToChannel;
        private final int totalChannels;

        public WriterUnionCoercedIntoRowBlockBuildingDecoder(Resolver.WriterUnion writerUnion, AvroTypeManager avroTypeManager)
                throws AvroTypeException
        {
            super(writerUnion, avroTypeManager);
            readUnionEquiv = writerUnion.unionEquiv;
            List<Schema> readSchemas = writerUnion.reader.getTypes();
            checkArgument(readSchemas.size() == writerUnion.actions.length, "each read schema must have resolvedAction For it");
            indexToChannel = getIndexToChannel(readSchemas);
            totalChannels = (int) IntStream.of(indexToChannel).filter(i -> i >= 0).count();
        }

        @Override
        protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder)
                throws IOException
        {
            int index = decoder.readIndex();
            if (readUnionEquiv) {
                // if no output channel then the schema is null and the whole record can be null;
                if (indexToChannel[index] < 0) {
                    NullBlockBuildingDecoder.INSTANCE.decodeIntoBlock(decoder, builder);
                }
                else {
                    // the index for the reader and writer are the same, so the channel for the index is used to select the field to populate
                    makeSingleRowWithTagAndAllFieldsNullButOne(indexToChannel[index], totalChannels, blockBuildingDecoders[index], decoder, builder);
                }
            }
            else {
                // delegate to ReaderUnionCoercedIntoRowBlockBuildingDecoder to get the output channel from the resolved action
                decodeIntoBlock(index, decoder, builder);
            }
        }

        // Writes one row value: the tag channel, then nulls in every field channel
        // except the one for outputChannel, which is decoded from the stream.
        protected static void makeSingleRowWithTagAndAllFieldsNullButOne(int outputChannel, int totalChannels, BlockBuildingDecoder blockBuildingDecoder, Decoder decoder, BlockBuilder builder)
                throws IOException
        {
            SingleRowBlockWriter currentBuilder = (SingleRowBlockWriter) builder.beginBlockEntry();
            // add tag with channel
            UNION_FIELD_TAG_TYPE.writeLong(currentBuilder.getFieldBlockBuilder(0), outputChannel);
            // add in null fields except one
            for (int channel = 1; channel <= totalChannels; channel++) {
                if (channel == outputChannel + 1) {
                    blockBuildingDecoder.decodeIntoBlock(decoder, currentBuilder.getFieldBlockBuilder(channel));
                }
                else {
                    currentBuilder.getFieldBlockBuilder(channel).appendNull();
                }
            }
            builder.closeEntry();
        }

        // NULL branches get channel -1 (they have no row field); all other branches
        // are assigned consecutive channels in branch order.
        protected static int[] getIndexToChannel(List<Schema> schemas)
        {
            int[] indexToChannel = new int[schemas.size()];
            int outputChannel = 0;
            for (int i = 0; i < indexToChannel.length; i++) {
                if (schemas.get(i).getType() == Schema.Type.NULL) {
                    indexToChannel[i] = -1;
                }
                else {
                    indexToChannel[i] = outputChannel++;
                }
            }
            return indexToChannel;
        }
    }
indexToChannel[readerUnion.firstMatch]; + delegateBuilder = createBlockBuildingDecoderForAction(readerUnion.actualAction, avroTypeManager); + totalChannels = (int) IntStream.of(indexToChannel).filter(i -> i >= 0).count(); + } + + @Override + protected void decodeIntoBlock(Decoder decoder, BlockBuilder builder) + throws IOException + { + if (outputChannel < 0) { + // No outputChannel for Null schema in union, null out coerces struct + NullBlockBuildingDecoder.INSTANCE.decodeIntoBlock(decoder, builder); + } + else { + WriterUnionCoercedIntoRowBlockBuildingDecoder + .makeSingleRowWithTagAndAllFieldsNullButOne(outputChannel, totalChannels, delegateBuilder, decoder, builder); + } + } + } + + private static LongIoFunction getLongPromotionFunction(Schema writerSchema) + throws AvroTypeException + { + if (writerSchema.getType() == Schema.Type.INT) { + return Decoder::readInt; + } + throw new AvroTypeException("Cannot promote type %s to long".formatted(writerSchema.getType())); + } + + private static FloatIoFunction getFloatPromotionFunction(Schema writerSchema) + throws AvroTypeException + { + return switch (writerSchema.getType()) { + case INT -> Decoder::readInt; + case LONG -> Decoder::readLong; + default -> throw new AvroTypeException("Cannot promote type %s to float".formatted(writerSchema.getType())); + }; + } + + private static DoubleIoFunction getDoublePromotionFunction(Schema writerSchema) + throws AvroTypeException + { + return switch (writerSchema.getType()) { + case INT -> Decoder::readInt; + case LONG -> Decoder::readLong; + case FLOAT -> Decoder::readFloat; + default -> throw new AvroTypeException("Cannot promote type %s to double".formatted(writerSchema.getType())); + }; + } + + @FunctionalInterface + private interface LongIoFunction + { + long apply(A a) + throws IOException; + } + + @FunctionalInterface + private interface FloatIoFunction + { + float apply(A a) + throws IOException; + } + + @FunctionalInterface + private interface DoubleIoFunction + { + 
double apply(A a) + throws IOException; + } + + @FunctionalInterface + private interface IoConsumer + { + void accept(A a) + throws IOException; + } + + // Avro supports default values for reader record fields that are missing in the writer schema + // the bytes representing the default field value are passed to a block building decoder + // so that it can pack the block appropriately for the default type. + private static IoConsumer getDefaultBlockBuilder(Schema.Field field, AvroTypeManager typeManager) + throws AvroTypeException + { + BlockBuildingDecoder buildingDecoder = createBlockBuildingDecoderForAction(Resolver.resolve(field.schema(), field.schema()), typeManager); + byte[] defaultBytes = getDefaultByes(field); + BinaryDecoder reuse = DecoderFactory.get().binaryDecoder(defaultBytes, null); + return blockBuilder -> buildingDecoder.decodeIntoBlock(DecoderFactory.get().binaryDecoder(defaultBytes, reuse), blockBuilder); + } + + private static byte[] getDefaultByes(Schema.Field field) + throws AvroTypeException + { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Encoder e = EncoderFactory.get().binaryEncoder(out, null); + ResolvingGrammarGenerator.encode(e, field.schema(), Accessor.defaultValue(field)); + e.flush(); + return out.toByteArray(); + } + catch (IOException exception) { + throw new AvroTypeException("Unable to encode to bytes for default value in field " + field, exception); + } + } + + /** + * Used for throwing {@link AvroTypeException} through interfaces that can not throw checked exceptions like DatumReader + */ + protected static class UncheckedAvroTypeException + extends RuntimeException + { + private final AvroTypeException avroTypeException; + + public UncheckedAvroTypeException(AvroTypeException cause) + { + super(requireNonNull(cause, "cause is null")); + avroTypeException = cause; + } + + public AvroTypeException getAvroTypeException() + { + return avroTypeException; + } + } +} diff --git 
/**
 * Checked exception raised when an Avro schema cannot be mapped to a Trino type
 * or when schema resolution between writer and reader fails.
 */
public class AvroTypeException
        extends Exception
{
    // Wraps Avro's unchecked org.apache.avro.AvroTypeException into this checked form.
    public AvroTypeException(org.apache.avro.AvroTypeException runtimeAvroTypeException)
    {
        super(runtimeAvroTypeException);
    }

    public AvroTypeException(String message)
    {
        super(message);
    }

    public AvroTypeException(String message, Throwable cause)
    {
        super(message, cause);
    }
}
/**
 * Extension point that lets callers override how Avro schemas map to Trino types
 * and how decoded Avro objects are written into Trino blocks.
 */
public interface AvroTypeManager
{
    /**
     * Called when the type manager is reading out data from a data file such as in {@link AvroFileReader}
     *
     * @param fileMetadata metadata from the file header
     */
    void configure(Map<String, byte[]> fileMetadata);

    /**
     * Returns the Trino type to use for the given schema, or empty to fall back
     * to the default schema-to-type mapping.
     *
     * @throws AvroTypeException if the schema is recognized but invalid for this manager
     */
    Optional<Type> overrideTypeForSchema(Schema schema)
            throws AvroTypeException;

    /**
     * Object provided by FasterReader's deserialization with no conversions.
     * Object class determined by Avro's standard generic data process
     * BlockBuilder provided by Type returned above for the schema
     * Possible to override for each primitive type as well.
     */
    Optional<BiConsumer<BlockBuilder, Object>> overrideBuildingFunctionForSchema(Schema schema)
            throws AvroTypeException;
}
/**
 * Maps Avro schemas to Trino SPI types, delegating to an {@link AvroTypeManager}
 * for per-schema overrides and coercing non-nullable unions into row types.
 */
public final class AvroTypeUtils
{
    private AvroTypeUtils() {}

    public static Type typeFromAvro(Schema schema, AvroTypeManager avroTypeManager)
            throws AvroTypeException
    {
        return typeFromAvro(schema, avroTypeManager, new HashSet<>());
    }

    // enclosingRecords tracks record schemas on the current descent path to detect
    // recursive schemas, which have no finite Trino type representation.
    private static Type typeFromAvro(final Schema schema, AvroTypeManager avroTypeManager, Set<Schema> enclosingRecords)
            throws AvroTypeException
    {
        // Type-manager override wins over the default mapping.
        Optional<Type> customType = avroTypeManager.overrideTypeForSchema(schema);
        if (customType.isPresent()) {
            return customType.get();
        }
        return switch (schema.getType()) {
            case NULL -> throw new UnsupportedOperationException("No null column type support");
            case BOOLEAN -> BooleanType.BOOLEAN;
            case INT -> IntegerType.INTEGER;
            case LONG -> BigintType.BIGINT;
            case FLOAT -> RealType.REAL;
            case DOUBLE -> DoubleType.DOUBLE;
            case ENUM, STRING -> VarcharType.VARCHAR;
            case FIXED, BYTES -> VarbinaryType.VARBINARY;
            case ARRAY -> new ArrayType(typeFromAvro(schema.getElementType(), avroTypeManager, enclosingRecords));
            case MAP -> new MapType(VarcharType.VARCHAR, typeFromAvro(schema.getValueType(), avroTypeManager, enclosingRecords), new TypeOperators());
            case RECORD -> {
                if (!enclosingRecords.add(schema)) {
                    throw new UnsupportedOperationException("Unable to represent recursive avro schemas in Trino Type form");
                }
                ImmutableList.Builder<RowType.Field> rowFieldTypes = ImmutableList.builder();
                for (Schema.Field field : schema.getFields()) {
                    // Copy the seen-set per field so sibling branches may reuse the same record schema.
                    rowFieldTypes.add(new RowType.Field(Optional.of(field.name()), typeFromAvro(field.schema(), avroTypeManager, new HashSet<>(enclosingRecords))));
                }
                yield RowType.from(rowFieldTypes.build());
            }
            case UNION -> {
                if (isSimpleNullableUnion(schema)) {
                    // [null, T] unions map directly to a nullable T.
                    yield typeFromAvro(unwrapNullableUnion(schema), avroTypeManager, enclosingRecords);
                }
                else {
                    // Multi-branch unions are coerced to a (tag, field0, field1, ...) row.
                    yield rowTypeForUnion(schema, avroTypeManager, enclosingRecords);
                }
            }
        };
    }

    // A "simple nullable union" has exactly one non-null branch (e.g. ["null", "string"]).
    static boolean isSimpleNullableUnion(Schema schema)
    {
        verify(schema.isUnion(), "Schema must be union");
        return schema.getTypes().stream().filter(not(Schema::isNullable)).count() == 1L;
    }

    private static Schema unwrapNullableUnion(Schema schema)
    {
        verify(schema.isUnion(), "Schema must be union");
        verify(schema.isNullable() && schema.getTypes().size() == 2);
        return schema.getTypes().stream().filter(not(Schema::isNullable)).collect(onlyElement());
    }

    // Builds the coerced row type from the union's non-null branches; the NULL
    // branch (if any) is represented by the row itself being null.
    private static RowType rowTypeForUnion(Schema schema, AvroTypeManager avroTypeManager, Set<Schema> enclosingRecords)
            throws AvroTypeException
    {
        verify(schema.isUnion());
        ImmutableList.Builder<Type> unionTypes = ImmutableList.builder();
        for (Schema variant : schema.getTypes()) {
            if (!variant.isNullable()) {
                unionTypes.add(typeFromAvro(variant, avroTypeManager, enclosingRecords));
            }
        }
        return rowTypeForUnionOfTypes(unionTypes.build());
    }
}
package io.trino.hive.formats.avro;

import io.airlift.log.Logger;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Int128;
import io.trino.spi.type.TimeType;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.Timestamps;
import io.trino.spi.type.Type;
import io.trino.spi.type.UuidType;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.generic.GenericFixed;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.Function;

import static io.trino.spi.type.UuidType.javaUuidToTrinoUuid;
import static java.util.Objects.requireNonNull;
import static org.apache.avro.LogicalTypes.fromSchemaIgnoreInvalid;

/**
 * An implementation that translates Avro Standard Logical types into Trino SPI types
 */
public class NativeLogicalTypesAvroTypeManager
        implements AvroTypeManager
{
    private static final Logger log = Logger.get(NativeLogicalTypesAvroTypeManager.class);

    // Canonical schemas for each supported logical type, initialized in the static block below.
    public static final Schema TIMESTAMP_MILLIS_SCHEMA;
    public static final Schema TIMESTAMP_MICROS_SCHEMA;
    public static final Schema DATE_SCHEMA;
    public static final Schema TIME_MILLIS_SCHEMA;
    public static final Schema TIME_MICROS_SCHEMA;
    public static final Schema UUID_SCHEMA;

    // Copied from org.apache.avro.LogicalTypes
    protected static final String DECIMAL = "decimal";
    protected static final String UUID = "uuid";
    protected static final String DATE = "date";
    protected static final String TIME_MILLIS = "time-millis";
    protected static final String TIME_MICROS = "time-micros";
    protected static final String TIMESTAMP_MILLIS = "timestamp-millis";
    protected static final String TIMESTAMP_MICROS = "timestamp-micros";
    protected static final String LOCAL_TIMESTAMP_MILLIS = "local-timestamp-millis";
    protected static final String LOCAL_TIMESTAMP_MICROS = "local-timestamp-micros";

    static {
        // Each logical type annotates the Avro base type mandated by the Avro spec.
        TIMESTAMP_MILLIS_SCHEMA = SchemaBuilder.builder().longType();
        LogicalTypes.timestampMillis().addToSchema(TIMESTAMP_MILLIS_SCHEMA);
        TIMESTAMP_MICROS_SCHEMA = SchemaBuilder.builder().longType();
        LogicalTypes.timestampMicros().addToSchema(TIMESTAMP_MICROS_SCHEMA);
        DATE_SCHEMA = Schema.create(Schema.Type.INT);
        LogicalTypes.date().addToSchema(DATE_SCHEMA);
        TIME_MILLIS_SCHEMA = Schema.create(Schema.Type.INT);
        LogicalTypes.timeMillis().addToSchema(TIME_MILLIS_SCHEMA);
        TIME_MICROS_SCHEMA = Schema.create(Schema.Type.LONG);
        LogicalTypes.timeMicros().addToSchema(TIME_MICROS_SCHEMA);
        UUID_SCHEMA = Schema.create(Schema.Type.STRING);
        LogicalTypes.uuid().addToSchema(UUID_SCHEMA);
    }

    @Override
    public void configure(Map<String, byte[]> fileMetadata) {}

    @Override
    public Optional<Type> overrideTypeForSchema(Schema schema)
            throws AvroTypeException
    {
        return validateAndProduceFromName(schema, NativeLogicalTypesAvroTypeManager::getAvroLogicalTypeSpiType);
    }

    @Override
    public Optional<BiConsumer<BlockBuilder, Object>> overrideBuildingFunctionForSchema(Schema schema)
            throws AvroTypeException
    {
        return validateAndProduceFromName(schema, getLogicalTypeBuildingFunction(schema));
    }

    // Maps a validated native logical type to its Trino SPI type.
    private static Type getAvroLogicalTypeSpiType(LogicalType logicalType)
    {
        return switch (logicalType.getName()) {
            case TIMESTAMP_MILLIS -> TimestampType.TIMESTAMP_MILLIS;
            case TIMESTAMP_MICROS -> TimestampType.TIMESTAMP_MICROS;
            case DECIMAL -> {
                LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType;
                yield DecimalType.createDecimalType(decimal.getPrecision(), decimal.getScale());
            }
            case DATE -> DateType.DATE;
            case TIME_MILLIS -> TimeType.TIME_MILLIS;
            case TIME_MICROS -> TimeType.TIME_MICROS;
            case UUID -> UuidType.UUID;
            // Only reachable if validateLogicalType admits a name this switch does not cover.
            default -> throw new IllegalStateException("Unreachable unfiltered logical type");
        };
    }
throw new IllegalStateException("Unreachable unfiltered logical type"); + }; + } + + private static Function> getLogicalTypeBuildingFunction(Schema schema) + { + return logicalType -> switch (logicalType.getName()) { + case TIMESTAMP_MILLIS -> { + if (schema.getType() == Schema.Type.LONG) { + yield (builder, obj) -> { + Long l = (Long) obj; + TimestampType.TIMESTAMP_MILLIS.writeLong(builder, l * Timestamps.MICROSECONDS_PER_MILLISECOND); + }; + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case TIMESTAMP_MICROS -> { + if (schema.getType() == Schema.Type.LONG) { + yield (builder, obj) -> { + Long l = (Long) obj; + TimestampType.TIMESTAMP_MICROS.writeLong(builder, l); + }; + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case DECIMAL -> { + LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType; + DecimalType decimalType = DecimalType.createDecimalType(decimal.getPrecision(), decimal.getScale()); + Function byteExtract = switch (schema.getType()) { + case BYTES -> // This is only safe because we don't reuse byte buffer objects which means each gets sized exactly for the bytes contained + (obj) -> ((ByteBuffer) obj).array(); + case FIXED -> (obj) -> ((GenericFixed) obj).bytes(); + default -> throw new IllegalStateException("Unreachable unfiltered logical type"); + }; + if (decimalType.isShort()) { + yield (builder, obj) -> decimalType.writeLong(builder, fromBigEndian(byteExtract.apply(obj))); + } + else { + yield (builder, obj) -> decimalType.writeObject(builder, Int128.fromBigEndian(byteExtract.apply(obj))); + } + } + case DATE -> { + if (schema.getType() == Schema.Type.INT) { + yield (builder, obj) -> { + Integer i = (Integer) obj; + DateType.DATE.writeLong(builder, i.longValue()); + }; + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case TIME_MILLIS -> { + if (schema.getType() == Schema.Type.INT) { + yield (builder, obj) -> { + Integer i = (Integer) 
obj; + TimeType.TIME_MILLIS.writeLong(builder, i.longValue() * Timestamps.PICOSECONDS_PER_MILLISECOND); + }; + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case TIME_MICROS -> { + if (schema.getType() == Schema.Type.LONG) { + yield (builder, obj) -> { + Long i = (Long) obj; + TimeType.TIME_MICROS.writeLong(builder, i * Timestamps.PICOSECONDS_PER_MICROSECOND); + }; + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + case UUID -> { + if (schema.getType() == Schema.Type.STRING) { + yield (builder, obj) -> UuidType.UUID.writeSlice(builder, javaUuidToTrinoUuid(java.util.UUID.fromString(obj.toString()))); + } + throw new IllegalStateException("Unreachable unfiltered logical type"); + } + default -> throw new IllegalStateException("Unreachable unfiltered logical type"); + }; + } + + private Optional validateAndProduceFromName(Schema schema, Function produce) + { + // TODO replace with switch sealed class syntax when stable + ValidateLogicalTypeResult logicalTypeResult = validateLogicalType(schema); + if (logicalTypeResult instanceof NoLogicalType ignored) { + return Optional.empty(); + } + if (logicalTypeResult instanceof NonNativeAvroLogicalType ignored) { + log.debug("Unrecognized logical type " + schema); + return Optional.empty(); + } + if (logicalTypeResult instanceof InvalidNativeAvroLogicalType invalidNativeAvroLogicalType) { + log.debug(invalidNativeAvroLogicalType.getCause(), "Invalidly configured native avro logical type"); + return Optional.empty(); + } + if (logicalTypeResult instanceof ValidNativeAvroLogicalType validNativeAvroLogicalType) { + return Optional.of(produce.apply(validNativeAvroLogicalType.getLogicalType())); + } + throw new IllegalStateException("Unhandled validate logical type result"); + } + + protected static ValidateLogicalTypeResult validateLogicalType(Schema schema) + { + final String typeName = schema.getProp(LogicalType.LOGICAL_TYPE_PROP); + if (typeName == null) { + 
return new NoLogicalType(); + } + LogicalType logicalType; + switch (typeName) { + case TIMESTAMP_MILLIS, TIMESTAMP_MICROS, DECIMAL, DATE, TIME_MILLIS, TIME_MICROS, UUID: + logicalType = fromSchemaIgnoreInvalid(schema); + break; + case LOCAL_TIMESTAMP_MICROS + LOCAL_TIMESTAMP_MILLIS: + log.debug("Logical type " + typeName + " not currently supported by by Trino"); + // fall through + default: + return new NonNativeAvroLogicalType(typeName); + } + // make sure the type is valid before returning it + if (logicalType != null) { + try { + logicalType.validate(schema); + } + catch (RuntimeException e) { + return new InvalidNativeAvroLogicalType(typeName, e); + } + return new ValidNativeAvroLogicalType(logicalType); + } + else { + return new NonNativeAvroLogicalType(typeName); + } + } + + protected abstract static sealed class ValidateLogicalTypeResult + permits NoLogicalType, NonNativeAvroLogicalType, InvalidNativeAvroLogicalType, ValidNativeAvroLogicalType {} + + protected static final class NoLogicalType + extends ValidateLogicalTypeResult {} + + protected static final class NonNativeAvroLogicalType + extends ValidateLogicalTypeResult + { + private final String logicalTypeName; + + public NonNativeAvroLogicalType(String logicalTypeName) + { + this.logicalTypeName = requireNonNull(logicalTypeName, "logicalTypeName is null"); + } + + public String getLogicalTypeName() + { + return logicalTypeName; + } + } + + protected static final class InvalidNativeAvroLogicalType + extends ValidateLogicalTypeResult + { + private final String logicalTypeName; + private final RuntimeException cause; + + public InvalidNativeAvroLogicalType(String logicalTypeName, RuntimeException cause) + { + this.logicalTypeName = requireNonNull(logicalTypeName, "logicalTypeName"); + this.cause = requireNonNull(cause, "cause is null"); + } + + public String getLogicalTypeName() + { + return logicalTypeName; + } + + public RuntimeException getCause() + { + return cause; + } + } + + protected static 
    protected static final class ValidNativeAvroLogicalType
            extends ValidateLogicalTypeResult
    {
        private final LogicalType logicalType;

        public ValidNativeAvroLogicalType(LogicalType logicalType)
        {
            this.logicalType = requireNonNull(logicalType, "logicalType is null");
        }

        public LogicalType getLogicalType()
        {
            return logicalType;
        }
    }

    // Reads 8 bytes of a byte[] as a big-endian long without manual shifting.
    private static final VarHandle BIG_ENDIAN_LONG_VIEW = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.BIG_ENDIAN);

    /**
     * Decode a long from the two's complement big-endian representation.
     *
     * @param bytes the two's complement big-endian encoding of the number. It must contain at least 1 byte.
     * It may contain more than 8 bytes if the leading bytes are not significant (either zeros or -1)
     * @throws ArithmeticException if the bytes represent a number outside the range [-2^63, 2^63 - 1]
     */
    // Styled from io.trino.spi.type.Int128.fromBigEndian
    public static long fromBigEndian(byte[] bytes)
    {
        if (bytes.length > 8) {
            // Take the trailing 8 bytes as the value; leading bytes must be pure
            // sign extension (all 0x00 for non-negative, all 0xFF for negative).
            int offset = bytes.length - Long.BYTES;
            long res = (long) BIG_ENDIAN_LONG_VIEW.get(bytes, offset);
            // verify that the significant bits above 64 bits are proper sign extension
            int expectedSignExtensionByte = (int) (res >> 63);
            for (int i = 0; i < offset; i++) {
                if (bytes[i] != expectedSignExtensionByte) {
                    throw new ArithmeticException("Overflow");
                }
            }
            return res;
        }
        if (bytes.length == 8) {
            return (long) BIG_ENDIAN_LONG_VIEW.get(bytes, 0);
        }
        // Fewer than 8 bytes: seed the accumulator with the sign of the first byte
        // (0 or -1), then shift each byte in; the seed provides the sign extension.
        long res = (bytes[0] >> 7);
        for (byte b : bytes) {
            res = (res << 8) | (b & 0xFF);
        }
        return res;
    }
}
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.hive.formats.avro; + +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.Type; +import org.apache.avro.Schema; + +import java.util.Map; +import java.util.Optional; +import java.util.function.BiConsumer; + +public class NoOpAvroTypeManager + implements AvroTypeManager +{ + public static final NoOpAvroTypeManager INSTANCE = new NoOpAvroTypeManager(); + + private NoOpAvroTypeManager() {} + + @Override + public void configure(Map fileMetadata) {} + + @Override + public Optional overrideTypeForSchema(Schema schema) + throws AvroTypeException + { + return Optional.empty(); + } + + @Override + public Optional> overrideBuildingFunctionForSchema(Schema schema) + throws AvroTypeException + { + return Optional.empty(); + } +} diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithAvroNativeTypeManagement.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithAvroNativeTypeManagement.java new file mode 100644 index 000000000000..4e94ff369140 --- /dev/null +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithAvroNativeTypeManagement.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.hive.formats.avro; + +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Longs; +import io.trino.filesystem.TrinoInputFile; +import io.trino.spi.Page; +import io.trino.spi.type.DateType; +import io.trino.spi.type.DecimalType; +import io.trino.spi.type.Int128; +import io.trino.spi.type.SqlDate; +import io.trino.spi.type.SqlDecimal; +import io.trino.spi.type.SqlTime; +import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.TimeType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.Type; +import io.trino.spi.type.UuidType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.util.Date; +import java.util.UUID; + +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.DATE_SCHEMA; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.TIMESTAMP_MICROS_SCHEMA; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.TIMESTAMP_MILLIS_SCHEMA; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.TIME_MICROS_SCHEMA; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.TIME_MILLIS_SCHEMA; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.UUID_SCHEMA; +import static 
io.trino.hive.formats.avro.TestAvroPageDataReaderWithoutTypeManager.createWrittenFileWithData; +import static io.trino.hive.formats.avro.TestAvroPageDataReaderWithoutTypeManager.createWrittenFileWithSchema; +import static io.trino.hive.formats.avro.TestLongFromBigEndian.padBigEndianCorrectly; +import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +public class TestAvroPageDataReaderWithAvroNativeTypeManagement +{ + private static final Schema DECIMAL_SMALL_BYTES_SCHEMA; + private static final int SMALL_FIXED_SIZE = 8; + private static final int LARGE_FIXED_SIZE = 9; + private static final Schema DECIMAL_SMALL_FIXED_SCHEMA; + private static final Schema DECIMAL_LARGE_BYTES_SCHEMA; + private static final Schema DECIMAL_LARGE_FIXED_SCHEMA; + private static final Date testTime = new Date(780681600000L); + private static final Type SMALL_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION - 1, 2); + private static final Type LARGE_DECIMAL_TYPE = DecimalType.createDecimalType(MAX_SHORT_PRECISION + 1, 2); + + static { + LogicalTypes.Decimal small = LogicalTypes.decimal(MAX_SHORT_PRECISION - 1, 2); + LogicalTypes.Decimal large = LogicalTypes.decimal(MAX_SHORT_PRECISION + 1, 2); + DECIMAL_SMALL_BYTES_SCHEMA = Schema.create(Schema.Type.BYTES); + small.addToSchema(DECIMAL_SMALL_BYTES_SCHEMA); + DECIMAL_SMALL_FIXED_SCHEMA = Schema.createFixed("smallDecimal", "myFixed", "namespace", SMALL_FIXED_SIZE); + small.addToSchema(DECIMAL_SMALL_FIXED_SCHEMA); + DECIMAL_LARGE_BYTES_SCHEMA = Schema.create(Schema.Type.BYTES); + large.addToSchema(DECIMAL_LARGE_BYTES_SCHEMA); + DECIMAL_LARGE_FIXED_SCHEMA = Schema.createFixed("largeDecimal", "myFixed", "namespace", (int) ((MAX_SHORT_PRECISION + 2) * Math.log(10) / Math.log(2) / 8) + 1); + large.addToSchema(DECIMAL_LARGE_FIXED_SCHEMA); + } + + @Test + public void testTypesSimple() + throws IOException, AvroTypeException + { + Schema schema = 
SchemaBuilder.builder() + .record("allSupported") + .fields() + .name("timestampMillis") + .type(TIMESTAMP_MILLIS_SCHEMA).noDefault() + .name("timestampMicros") + .type(TIMESTAMP_MICROS_SCHEMA).noDefault() + .name("smallBytesDecimal") + .type(DECIMAL_SMALL_BYTES_SCHEMA).noDefault() + .name("smallFixedDecimal") + .type(DECIMAL_SMALL_FIXED_SCHEMA).noDefault() + .name("largeBytesDecimal") + .type(DECIMAL_LARGE_BYTES_SCHEMA).noDefault() + .name("largeFixedDecimal") + .type(DECIMAL_LARGE_FIXED_SCHEMA).noDefault() + .name("date") + .type(DATE_SCHEMA).noDefault() + .name("timeMillis") + .type(TIME_MILLIS_SCHEMA).noDefault() + .name("timeMicros") + .type(TIME_MICROS_SCHEMA).noDefault() + .name("id") + .type(UUID_SCHEMA).noDefault() + .endRecord(); + + GenericData.Fixed genericSmallFixedDecimal = new GenericData.Fixed(DECIMAL_SMALL_FIXED_SCHEMA); + genericSmallFixedDecimal.bytes(padBigEndianCorrectly(78068160000000L, SMALL_FIXED_SIZE)); + GenericData.Fixed genericLargeFixedDecimal = new GenericData.Fixed(DECIMAL_LARGE_FIXED_SCHEMA); + genericLargeFixedDecimal.bytes(padBigEndianCorrectly(78068160000000L, LARGE_FIXED_SIZE)); + UUID id = UUID.randomUUID(); + + GenericData.Record myRecord = new GenericData.Record(schema); + myRecord.put("timestampMillis", testTime.getTime()); + myRecord.put("timestampMicros", testTime.getTime() * 1000); + myRecord.put("smallBytesDecimal", ByteBuffer.wrap(Longs.toByteArray(78068160000000L))); + myRecord.put("smallFixedDecimal", genericSmallFixedDecimal); + myRecord.put("largeBytesDecimal", ByteBuffer.wrap(Int128.fromBigEndian(Longs.toByteArray(78068160000000L)).toBigEndianBytes())); + myRecord.put("largeFixedDecimal", genericLargeFixedDecimal); + myRecord.put("date", 9035); + myRecord.put("timeMillis", 39_600_000); + myRecord.put("timeMicros", 39_600_000_000L); + myRecord.put("id", id.toString()); + + TrinoInputFile input = createWrittenFileWithData(schema, ImmutableList.of(myRecord)); + try (AvroFileReader avroFileReader = new 
AvroFileReader(input, schema, new NativeLogicalTypesAvroTypeManager())) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + // Timestamps equal + SqlTimestamp milliTimestamp = (SqlTimestamp) TimestampType.TIMESTAMP_MILLIS.getObjectValue(null, p.getBlock(0), 0); + SqlTimestamp microTimestamp = (SqlTimestamp) TimestampType.TIMESTAMP_MICROS.getObjectValue(null, p.getBlock(1), 0); + assertThat(milliTimestamp).isEqualTo(microTimestamp.roundTo(3)); + assertThat(microTimestamp.getEpochMicros()).isEqualTo(testTime.getTime() * 1000); + + // Decimals Equal + SqlDecimal smallBytesDecimal = (SqlDecimal) SMALL_DECIMAL_TYPE.getObjectValue(null, p.getBlock(2), 0); + SqlDecimal smallFixedDecimal = (SqlDecimal) SMALL_DECIMAL_TYPE.getObjectValue(null, p.getBlock(3), 0); + SqlDecimal largeBytesDecimal = (SqlDecimal) LARGE_DECIMAL_TYPE.getObjectValue(null, p.getBlock(4), 0); + SqlDecimal largeFixedDecimal = (SqlDecimal) LARGE_DECIMAL_TYPE.getObjectValue(null, p.getBlock(5), 0); + + assertThat(smallBytesDecimal).isEqualTo(smallFixedDecimal); + assertThat(largeBytesDecimal).isEqualTo(largeFixedDecimal); + assertThat(smallBytesDecimal.toBigDecimal()).isEqualTo(largeBytesDecimal.toBigDecimal()); + assertThat(smallBytesDecimal.getUnscaledValue()).isEqualTo(new BigInteger(Longs.toByteArray(78068160000000L))); + + // Get date + SqlDate date = (SqlDate) DateType.DATE.getObjectValue(null, p.getBlock(6), 0); + assertThat(date.getDays()).isEqualTo(9035); + + // Time equals + SqlTime timeMillis = (SqlTime) TimeType.TIME_MILLIS.getObjectValue(null, p.getBlock(7), 0); + SqlTime timeMicros = (SqlTime) TimeType.TIME_MICROS.getObjectValue(null, p.getBlock(8), 0); + assertThat(timeMillis).isEqualTo(timeMicros.roundTo(3)); + assertThat(timeMillis.getPicos()).isEqualTo(timeMicros.getPicos()).isEqualTo(39_600_000_000L * 1_000_000L); + + //UUID + assertThat(id.toString()).isEqualTo(UuidType.UUID.getObjectValue(null, p.getBlock(9), 0)); + + totalRecords += 
p.getPositionCount(); + } + assertThat(totalRecords).isEqualTo(1); + } + } + + @Test + public void testWithDefaults() + throws IOException, AvroTypeException + { + String id = UUID.randomUUID().toString(); + Schema schema = SchemaBuilder.builder() + .record("testDefaults") + .fields() + .name("timestampMillis") + .type(TIMESTAMP_MILLIS_SCHEMA).withDefault(testTime.getTime()) + .name("smallBytesDecimal") + .type(DECIMAL_SMALL_BYTES_SCHEMA).withDefault(ByteBuffer.wrap(Longs.toByteArray(testTime.getTime()))) + .name("timeMicros") + .type(TIME_MICROS_SCHEMA).withDefault(39_600_000_000L) + .name("id") + .type(UUID_SCHEMA).withDefault(id) + .endRecord(); + Schema writeSchema = SchemaBuilder.builder() + .record("testDefaults") + .fields() + .name("notRead").type().optional().booleanType() + .endRecord(); + + TrinoInputFile input = createWrittenFileWithSchema(10, writeSchema); + try (AvroFileReader avroFileReader = new AvroFileReader(input, schema, new NativeLogicalTypesAvroTypeManager())) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + for (int i = 0; i < p.getPositionCount(); i++) { + // millis timestamp const + SqlTimestamp milliTimestamp = (SqlTimestamp) TimestampType.TIMESTAMP_MILLIS.getObjectValue(null, p.getBlock(0), i); + assertThat(milliTimestamp.getEpochMicros()).isEqualTo(testTime.getTime() * 1000); + + // decimal bytes const + SqlDecimal smallBytesDecimal = (SqlDecimal) SMALL_DECIMAL_TYPE.getObjectValue(null, p.getBlock(1), i); + assertThat(smallBytesDecimal.getUnscaledValue()).isEqualTo(new BigInteger(Longs.toByteArray(testTime.getTime()))); + + // time micros const + SqlTime timeMicros = (SqlTime) TimeType.TIME_MICROS.getObjectValue(null, p.getBlock(2), i); + assertThat(timeMicros.getPicos()).isEqualTo(39_600_000_000L * 1_000_000L); + + //UUID const assert + assertThat(id).isEqualTo(UuidType.UUID.getObjectValue(null, p.getBlock(3), i)); + } + totalRecords += p.getPositionCount(); + } + 
assertThat(totalRecords).isEqualTo(10); + } + } +} diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithoutTypeManager.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithoutTypeManager.java new file mode 100644 index 000000000000..d6bd5de44187 --- /dev/null +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestAvroPageDataReaderWithoutTypeManager.java @@ -0,0 +1,563 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.hive.formats.avro; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.primitives.Bytes; +import com.google.common.primitives.Longs; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.filesystem.TrinoInputFile; +import io.trino.filesystem.local.LocalFileSystem; +import io.trino.spi.Page; +import io.trino.spi.block.ArrayBlock; +import io.trino.spi.block.Block; +import io.trino.spi.block.ByteArrayBlock; +import io.trino.spi.block.IntArrayBlock; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.MapBlock; +import io.trino.spi.block.RowBlock; +import io.trino.spi.block.VariableWidthBlock; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.TypeOperators; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.util.RandomData; +import org.testng.annotations.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import 
java.util.concurrent.atomic.AtomicInteger; + +import static com.google.common.base.Verify.verify; +import static io.trino.block.BlockAssertions.assertBlockEquals; +import static io.trino.block.BlockAssertions.createIntsBlock; +import static io.trino.block.BlockAssertions.createRowBlock; +import static io.trino.block.BlockAssertions.createStringsBlock; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.within; + +public class TestAvroPageDataReaderWithoutTypeManager +{ + private static final TypeOperators TYPE_OPERATORS = new TypeOperators(); + private static final ArrayType ARRAY_INTEGER = new ArrayType(INTEGER); + private static final MapType MAP_VARCHAR_VARCHAR = new MapType(VARCHAR, VARCHAR, TYPE_OPERATORS); + private static final MapType MAP_VARCHAR_INTEGER = new MapType(VARCHAR, INTEGER, TYPE_OPERATORS); + private static final TrinoFileSystem TRINO_LOCAL_FILESYSTEM = new LocalFileSystem(Path.of("/")); + + private static final Schema SIMPLE_RECORD_SCHEMA = SchemaBuilder.record("simpleRecord") + .fields() + .name("a") + .type().intType().noDefault() + .name("b") + .type().doubleType().noDefault() + .name("c") + .type().stringType().noDefault() + .endRecord(); + + private static final Schema SIMPLE_ENUM_SCHEMA = SchemaBuilder.enumeration("myEnumType").symbols("A", "B", "C"); + + private static final Schema SIMPLE_ENUM_SUPER_SCHEMA = SchemaBuilder.enumeration("myEnumType").symbols("A", "B", "C", "D"); + + private static final Schema SIMPLE_ENUM_REORDERED = SchemaBuilder.enumeration("myEnumType").symbols("C", "D", "B", "A"); + + private static final Schema ALL_TYPES_RECORD_SCHEMA = SchemaBuilder.builder() + .record("all") + .fields() + .name("aBoolean") + .type().booleanType().noDefault() + .name("aInt") + .type().intType().noDefault() + .name("aLong") + 
.type().longType().noDefault() + .name("aFloat") + .type().floatType().noDefault() + .name("aDouble") + .type().doubleType().noDefault() + .name("aString") + .type().stringType().noDefault() + .name("aBytes") + .type().bytesType().noDefault() + .name("aFixed") + .type().fixed("myFixedType").size(16).noDefault() + .name("anArray") + .type().array().items().intType().noDefault() + .name("aMap") + .type().map().values().intType().noDefault() + .name("anEnum") + .type(SIMPLE_ENUM_SCHEMA).noDefault() + .name("aRecord") + .type(SIMPLE_RECORD_SCHEMA).noDefault() + .name("aUnion") + .type().optional().stringType() + .endRecord(); + + private static final GenericRecord ALL_TYPES_GENERIC_RECORD; + private static final GenericRecord SIMPLE_GENERIC_RECORD; + private static final String A_STRING_VALUE = "a test string"; + private static final ByteBuffer A_BYTES_VALUE = ByteBuffer.wrap("a test byte array".getBytes(StandardCharsets.UTF_8)); + private static final GenericData.Fixed A_FIXED_VALUE; + + static { + SIMPLE_GENERIC_RECORD = new GenericData.Record(SIMPLE_RECORD_SCHEMA); + SIMPLE_GENERIC_RECORD.put("a", 5); + SIMPLE_GENERIC_RECORD.put("b", 3.14159265358979); + SIMPLE_GENERIC_RECORD.put("c", "Simple Record String Field"); + + UUID fixed = UUID.nameUUIDFromBytes("a test fixed".getBytes(StandardCharsets.UTF_8)); + A_FIXED_VALUE = new GenericData.Fixed(SchemaBuilder.builder().fixed("myFixedType").size(16), Bytes.concat(Longs.toByteArray(fixed.getMostSignificantBits()), Longs.toByteArray(fixed.getLeastSignificantBits()))); + + ALL_TYPES_GENERIC_RECORD = new GenericData.Record(ALL_TYPES_RECORD_SCHEMA); + ALL_TYPES_GENERIC_RECORD.put("aBoolean", true); + ALL_TYPES_GENERIC_RECORD.put("aInt", 42); + ALL_TYPES_GENERIC_RECORD.put("aLong", 3400L); + ALL_TYPES_GENERIC_RECORD.put("aFloat", 3.14f); + ALL_TYPES_GENERIC_RECORD.put("aDouble", 9.81); + ALL_TYPES_GENERIC_RECORD.put("aString", A_STRING_VALUE); + ALL_TYPES_GENERIC_RECORD.put("aBytes", A_BYTES_VALUE); + 
ALL_TYPES_GENERIC_RECORD.put("aFixed", A_FIXED_VALUE); + ALL_TYPES_GENERIC_RECORD.put("anArray", ImmutableList.of(1, 2, 3, 4)); + ALL_TYPES_GENERIC_RECORD.put("aMap", ImmutableMap.of("key1", 1, "key2", 2)); + ALL_TYPES_GENERIC_RECORD.put("anEnum", new GenericData.EnumSymbol(SIMPLE_ENUM_SCHEMA, "A")); + ALL_TYPES_GENERIC_RECORD.put("aRecord", SIMPLE_GENERIC_RECORD); + ALL_TYPES_GENERIC_RECORD.put("aUnion", null); + } + + @Test + public void testAllTypesSimple() + throws IOException, AvroTypeException + { + TrinoInputFile input = createWrittenFileWithData(ALL_TYPES_RECORD_SCHEMA, ImmutableList.of(ALL_TYPES_GENERIC_RECORD)); + try (AvroFileReader avroFileReader = new AvroFileReader(input, ALL_TYPES_RECORD_SCHEMA, NoOpAvroTypeManager.INSTANCE)) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + assertIsAllTypesGenericRecord(p); + totalRecords += p.getPositionCount(); + } + assertThat(totalRecords).isEqualTo(1); + } + } + + @Test + public void testSchemaWithSkips() + throws IOException, AvroTypeException + { + SchemaBuilder.FieldAssembler fieldAssembler = SchemaBuilder.builder().record("simpleRecord").fields(); + fieldAssembler.name("notInAllTypeRecordSchema").type().optional().array().items().intType(); + Schema readSchema = fieldAssembler.endRecord(); + + int count = ThreadLocalRandom.current().nextInt(10000, 100000); + TrinoInputFile input = createWrittenFileWithSchema(count, ALL_TYPES_RECORD_SCHEMA); + try (AvroFileReader avroFileReader = new AvroFileReader(input, readSchema, NoOpAvroTypeManager.INSTANCE)) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + for (int pos = 0; pos < p.getPositionCount(); pos++) { + assertThat(p.getBlock(0).isNull(pos)).isTrue(); + } + totalRecords += p.getPositionCount(); + } + assertThat(totalRecords).isEqualTo(count); + } + } + + @Test + public void testSchemaWithDefaults() + throws IOException, AvroTypeException + { + 
SchemaBuilder.FieldAssembler fieldAssembler = SchemaBuilder.builder().record("simpleRecord").fields(); + fieldAssembler.name("defaultedField1").type().map().values().stringType().mapDefault(ImmutableMap.of("key1", "value1")); + for (Schema.Field field : SIMPLE_RECORD_SCHEMA.getFields()) { + fieldAssembler = fieldAssembler.name(field.name()).type(field.schema()).noDefault(); + } + fieldAssembler.name("defaultedField2").type().booleanType().booleanDefault(true); + Schema readerSchema = fieldAssembler.endRecord(); + + int count = ThreadLocalRandom.current().nextInt(10000, 100000); + TrinoInputFile input = createWrittenFileWithSchema(count, SIMPLE_RECORD_SCHEMA); + try (AvroFileReader avroFileReader = new AvroFileReader(input, readerSchema, NoOpAvroTypeManager.INSTANCE)) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + MapBlock mb = (MapBlock) p.getBlock(0); + MapBlock expected = (MapBlock) MAP_VARCHAR_VARCHAR.createBlockFromKeyValue(Optional.empty(), + new int[] {0, 1}, + createStringsBlock("key1"), + createStringsBlock("value1")); + mb = (MapBlock) mb.getRegion(0, 1); + assertBlockEquals(MAP_VARCHAR_VARCHAR, mb, expected); + + ByteArrayBlock block = (ByteArrayBlock) p.getBlock(readerSchema.getFields().size() - 1); + assertThat(block.getByte(0, 0)).isGreaterThan((byte) 0); + totalRecords += p.getPositionCount(); + } + assertThat(totalRecords).isEqualTo(count); + } + } + + @Test + public void testSchemaWithReorders() + throws IOException, AvroTypeException + { + Schema writerSchema = reverseSchema(ALL_TYPES_RECORD_SCHEMA); + TrinoInputFile input = createWrittenFileWithData(writerSchema, ImmutableList.of(reverseGenericRecord(ALL_TYPES_GENERIC_RECORD))); + try (AvroFileReader avroFileReader = new AvroFileReader(input, ALL_TYPES_RECORD_SCHEMA, NoOpAvroTypeManager.INSTANCE)) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + assertIsAllTypesGenericRecord(p); + totalRecords += 
p.getPositionCount(); + } + assertThat(totalRecords).isEqualTo(1); + } + } + + @Test + public void testPromotions() + throws IOException, AvroTypeException + { + SchemaBuilder.FieldAssembler writeSchemaBuilder = SchemaBuilder.builder().record("writeRecord").fields(); + SchemaBuilder.FieldAssembler readSchemaBuilder = SchemaBuilder.builder().record("readRecord").fields(); + + AtomicInteger fieldNum = new AtomicInteger(0); + Map> expectedBlockPerChannel = new HashMap<>(); + for (Schema.Type readType : Schema.Type.values()) { + List promotesFrom = switch (readType) { + case STRING -> ImmutableList.of(Schema.Type.BYTES); + case BYTES -> ImmutableList.of(Schema.Type.STRING); + case LONG -> ImmutableList.of(Schema.Type.INT); + case FLOAT -> ImmutableList.of(Schema.Type.INT, Schema.Type.LONG); + case DOUBLE -> ImmutableList.of(Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT); + case RECORD, ENUM, ARRAY, MAP, UNION, FIXED, INT, BOOLEAN, NULL -> ImmutableList.of(); + }; + for (Schema.Type writeType : promotesFrom) { + expectedBlockPerChannel.put(fieldNum.get(), switch (readType) { + case STRING, BYTES -> VariableWidthBlock.class; + case LONG, DOUBLE -> LongArrayBlock.class; + case FLOAT -> IntArrayBlock.class; + case RECORD, ENUM, ARRAY, MAP, UNION, FIXED, INT, BOOLEAN, NULL -> throw new IllegalStateException(); + }); + String fieldName = "field" + fieldNum.getAndIncrement(); + writeSchemaBuilder = writeSchemaBuilder.name(fieldName).type(Schema.create(writeType)).noDefault(); + readSchemaBuilder = readSchemaBuilder.name(fieldName).type(Schema.create(readType)).noDefault(); + } + } + + int count = ThreadLocalRandom.current().nextInt(10000, 100000); + TrinoInputFile input = createWrittenFileWithSchema(count, writeSchemaBuilder.endRecord()); + try (AvroFileReader avroFileReader = new AvroFileReader(input, readSchemaBuilder.endRecord(), NoOpAvroTypeManager.INSTANCE)) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + for 
(Map.Entry> channelClass : expectedBlockPerChannel.entrySet()) { + assertThat(p.getBlock(channelClass.getKey())).isInstanceOf(channelClass.getValue()); + } + totalRecords += p.getPositionCount(); + } + assertThat(totalRecords).isEqualTo(count); + } + } + + @Test + public void testEnum() + throws IOException, AvroTypeException + { + Schema base = SchemaBuilder.record("test").fields() + .name("myEnum") + .type(SIMPLE_ENUM_SCHEMA).noDefault() + .endRecord(); + Schema superSchema = SchemaBuilder.record("test").fields() + .name("myEnum") + .type(SIMPLE_ENUM_SUPER_SCHEMA).noDefault() + .endRecord(); + Schema reorderdSchema = SchemaBuilder.record("test").fields() + .name("myEnum") + .type(SIMPLE_ENUM_REORDERED).noDefault() + .endRecord(); + + GenericRecord expected = (GenericRecord) new RandomData(base, 1).iterator().next(); + + //test superset + TrinoInputFile input = createWrittenFileWithData(base, ImmutableList.of(expected)); + try (AvroFileReader avroFileReader = new AvroFileReader(input, superSchema, NoOpAvroTypeManager.INSTANCE)) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + String actualSymbol = new String(((Slice) VARCHAR.getObject(p.getBlock(0), 0)).getBytes(), StandardCharsets.UTF_8); + assertThat(actualSymbol).isEqualTo(expected.get("myEnum").toString()); + totalRecords += p.getPositionCount(); + } + assertThat(totalRecords).isEqualTo(1); + } + + //test reordered + input = createWrittenFileWithData(base, ImmutableList.of(expected)); + try (AvroFileReader avroFileReader = new AvroFileReader(input, reorderdSchema, NoOpAvroTypeManager.INSTANCE)) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + String actualSymbol = new String(((Slice) VarcharType.VARCHAR.getObject(p.getBlock(0), 0)).getBytes(), StandardCharsets.UTF_8); + assertThat(actualSymbol).isEqualTo(expected.get("myEnum").toString()); + totalRecords += p.getPositionCount(); + } + 
assertThat(totalRecords).isEqualTo(1); + } + } + + @Test + public void testCoercionOfUnionToStruct() + throws IOException, AvroTypeException + { + Schema complexUnion = Schema.createUnion(Schema.create(Schema.Type.INT), Schema.create(Schema.Type.STRING), Schema.create(Schema.Type.NULL)); + + Schema readSchema = SchemaBuilder.builder() + .record("testComplexUnions") + .fields() + .name("readStraightUp") + .type(complexUnion) + .noDefault() + .name("readFromReverse") + .type(complexUnion) + .noDefault() + .name("readFromDefault") + .type(complexUnion) + .withDefault(42) + .endRecord(); + + Schema writeSchema = SchemaBuilder.builder() + .record("testComplexUnions") + .fields() + .name("readStraightUp") + .type(complexUnion) + .noDefault() + .name("readFromReverse") + .type(Schema.createUnion(Schema.create(Schema.Type.STRING), Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT))) + .noDefault() + .endRecord(); + + GenericRecord stringsOnly = new GenericData.Record(writeSchema); + stringsOnly.put("readStraightUp", "I am in column 0 field 1"); + stringsOnly.put("readFromReverse", "I am in column 1 field 1"); + + GenericRecord ints = new GenericData.Record(writeSchema); + ints.put("readStraightUp", 5); + ints.put("readFromReverse", 21); + + GenericRecord nulls = new GenericData.Record(writeSchema); + nulls.put("readStraightUp", null); + nulls.put("readFromReverse", null); + + TrinoInputFile input = createWrittenFileWithData(writeSchema, ImmutableList.of(stringsOnly, ints, nulls)); + try (AvroFileReader avroFileReader = new AvroFileReader(input, readSchema, NoOpAvroTypeManager.INSTANCE)) { + int totalRecords = 0; + while (avroFileReader.hasNext()) { + Page p = avroFileReader.next(); + assertThat(p.getPositionCount()).withFailMessage("Page Batch should be at least 3").isEqualTo(3); + //check first column + //check first column first row coerced struct + Block readStraightUpStringsOnly = p.getBlock(0).getSingleValueBlock(0); + 
assertThat(readStraightUpStringsOnly.getChildren().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readStraightUpStringsOnly.getChildren().get(1).isNull(0)).isTrue(); // int field null + assertThat(VARCHAR.getObjectValue(null, readStraightUpStringsOnly.getChildren().get(2), 0)).isEqualTo("I am in column 0 field 1"); //string field expected value + // check first column second row coerced struct + Block readStraightUpInts = p.getBlock(0).getSingleValueBlock(1); + assertThat(readStraightUpInts.getChildren().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readStraightUpInts.getChildren().get(2).isNull(0)).isTrue(); // string field null + assertThat(INTEGER.getObjectValue(null, readStraightUpInts.getChildren().get(1), 0)).isEqualTo(5); + + //check first column third row is null + assertThat(p.getBlock(0).isNull(2)).isTrue(); + //check second column + //check second column first row coerced struct + Block readFromReverseStringsOnly = p.getBlock(1).getSingleValueBlock(0); + assertThat(readFromReverseStringsOnly.getChildren().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readFromReverseStringsOnly.getChildren().get(1).isNull(0)).isTrue(); // int field null + assertThat(VARCHAR.getObjectValue(null, readFromReverseStringsOnly.getChildren().get(2), 0)).isEqualTo("I am in column 1 field 1"); + //check second column second row coerced struct + Block readFromReverseUpInts = p.getBlock(1).getSingleValueBlock(1); + assertThat(readFromReverseUpInts.getChildren().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readFromReverseUpInts.getChildren().get(2).isNull(0)).isTrue(); // string field null + assertThat(INTEGER.getObjectValue(null, readFromReverseUpInts.getChildren().get(1), 0)).isEqualTo(21); + //check second column third row is null + assertThat(p.getBlock(1).isNull(2)).isTrue(); + + //check third column (default of 42 always) + //check third column first row coerced struct + Block 
readFromDefaultStringsOnly = p.getBlock(2).getSingleValueBlock(0); + assertThat(readFromDefaultStringsOnly.getChildren().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readFromDefaultStringsOnly.getChildren().get(2).isNull(0)).isTrue(); // string field null + assertThat(INTEGER.getObjectValue(null, readFromDefaultStringsOnly.getChildren().get(1), 0)).isEqualTo(42); + //check third column second row coerced struct + Block readFromDefaultInts = p.getBlock(2).getSingleValueBlock(1); + assertThat(readFromDefaultInts.getChildren().size()).isEqualTo(3); // tag, int and string block fields + assertThat(readFromDefaultInts.getChildren().get(2).isNull(0)).isTrue(); // string field null + assertThat(INTEGER.getObjectValue(null, readFromDefaultInts.getChildren().get(1), 0)).isEqualTo(42); + //check third column third row coerced struct + Block readFromDefaultNulls = p.getBlock(2).getSingleValueBlock(2); + assertThat(readFromDefaultNulls.getChildren().size()).isEqualTo(3); // int and string block fields + assertThat(readFromDefaultNulls.getChildren().get(2).isNull(0)).isTrue(); // string field null + assertThat(INTEGER.getObjectValue(null, readFromDefaultNulls.getChildren().get(1), 0)).isEqualTo(42); + + totalRecords += p.getPositionCount(); + } + assertThat(totalRecords).isEqualTo(3); + } + } + + protected static TrinoInputFile createWrittenFileWithData(Schema schema, List records) + throws IOException + { + File tempFile = File.createTempFile("testingAvroReading", null); + try (DataFileWriter fileWriter = new DataFileWriter<>(new GenericDatumWriter<>())) { + fileWriter.create(schema, tempFile); + for (GenericRecord genericRecord : records) { + fileWriter.append(genericRecord); + } + } + tempFile.deleteOnExit(); + return TRINO_LOCAL_FILESYSTEM.newInputFile(Location.of("local://" + tempFile.getAbsolutePath())); + } + + protected static TrinoInputFile createWrittenFileWithSchema(int count, Schema schema) + throws IOException + { + Iterator randomData = 
new RandomData(schema, count).iterator(); + File tempFile = File.createTempFile("testingAvroReading", null); + try (DataFileWriter fileWriter = new DataFileWriter<>(new GenericDatumWriter<>())) { + fileWriter.create(schema, tempFile); + while (randomData.hasNext()) { + fileWriter.append((GenericRecord) randomData.next()); + } + } + tempFile.deleteOnExit(); + return TRINO_LOCAL_FILESYSTEM.newInputFile(Location.of("local://" + tempFile.getAbsolutePath())); + } + + private static GenericRecord reverseGenericRecord(GenericRecord record) + { + Schema reversedSchema = reverseSchema(record.getSchema()); + GenericRecordBuilder recordBuilder = new GenericRecordBuilder(reversedSchema); + for (Schema.Field field : reversedSchema.getFields()) { + if (field.schema().getType() == Schema.Type.RECORD) { + recordBuilder.set(field, reverseGenericRecord((GenericRecord) record.get(field.name()))); + } + else { + recordBuilder.set(field, record.get(field.name())); + } + } + return recordBuilder.build(); + } + + private static Schema reverseSchema(Schema schema) + { + verify(schema.getType() == Schema.Type.RECORD); + SchemaBuilder.FieldAssembler fieldAssembler = SchemaBuilder.builder().record(schema.getName()).fields(); + for (Schema.Field field : Lists.reverse(schema.getFields())) { + if (field.schema().getType() == Schema.Type.ENUM) { + fieldAssembler = fieldAssembler.name(field.name()) + .type(Schema.createEnum(field.schema().getName(), field.schema().getDoc(), field.schema().getNamespace(), Lists.reverse(field.schema().getEnumSymbols()))) + .noDefault(); + } + else if (field.schema().getType() == Schema.Type.UNION) { + fieldAssembler = fieldAssembler.name(field.name()) + .type(Schema.createUnion(Lists.reverse(field.schema().getTypes()))) + .noDefault(); + } + else if (field.schema().getType() == Schema.Type.RECORD) { + fieldAssembler = fieldAssembler.name(field.name()) + .type(reverseSchema(field.schema())) + .noDefault(); + } + else { + fieldAssembler = 
fieldAssembler.name(field.name()).type(field.schema()).noDefault(); + } + } + return fieldAssembler.endRecord(); + } + + private static void assertIsAllTypesGenericRecord(Page p) + { + // test boolean + assertThat(p.getBlock(0)).isInstanceOf(ByteArrayBlock.class); + assertThat(BooleanType.BOOLEAN.getBoolean(p.getBlock(0), 0)).isTrue(); + // test int + assertThat(p.getBlock(1)).isInstanceOf(IntArrayBlock.class); + assertThat(INTEGER.getInt(p.getBlock(1), 0)).isEqualTo(42); + // test long + assertThat(p.getBlock(2)).isInstanceOf(LongArrayBlock.class); + assertThat(BigintType.BIGINT.getLong(p.getBlock(2), 0)).isEqualTo(3400L); + // test float + assertThat(p.getBlock(3)).isInstanceOf(IntArrayBlock.class); + assertThat(RealType.REAL.getFloat(p.getBlock(3), 0)).isCloseTo(3.14f, within(0.001f)); + // test double + assertThat(p.getBlock(4)).isInstanceOf(LongArrayBlock.class); + assertThat(DoubleType.DOUBLE.getDouble(p.getBlock(4), 0)).isCloseTo(9.81, within(0.001)); + // test string + assertThat(p.getBlock(5)).isInstanceOf(VariableWidthBlock.class); + assertThat(VARCHAR.getObject(p.getBlock(5), 0)).isEqualTo(Slices.utf8Slice(A_STRING_VALUE)); + // test bytes + assertThat(p.getBlock(6)).isInstanceOf(VariableWidthBlock.class); + assertThat(VarbinaryType.VARBINARY.getObject(p.getBlock(6), 0)).isEqualTo(Slices.wrappedBuffer(A_BYTES_VALUE)); + // test fixed + assertThat(p.getBlock(7)).isInstanceOf(VariableWidthBlock.class); + assertThat(VarbinaryType.VARBINARY.getObject(p.getBlock(7), 0)).isEqualTo(Slices.wrappedBuffer(A_FIXED_VALUE.bytes())); + //test array + assertThat(p.getBlock(8)).isInstanceOf(ArrayBlock.class); + assertThat(ARRAY_INTEGER.getObject(p.getBlock(8), 0)).isInstanceOf(IntArrayBlock.class); + assertBlockEquals(INTEGER, ARRAY_INTEGER.getObject(p.getBlock(8), 0), createIntsBlock(1, 2, 3, 4)); + // test map + assertThat(p.getBlock(9)).isInstanceOf(MapBlock.class); + assertThat(MAP_VARCHAR_INTEGER.getObjectValue(null, p.getBlock(9), 
0)).isEqualTo(ImmutableMap.of("key1", 1, "key2", 2)); + // test enum + assertThat(p.getBlock(10)).isInstanceOf(VariableWidthBlock.class); + assertThat(VARCHAR.getObject(p.getBlock(10), 0)).isEqualTo(Slices.utf8Slice("A")); + // test record + assertThat(p.getBlock(11)).isInstanceOf(RowBlock.class); + Block expected = createRowBlock(ImmutableList.of(INTEGER, DoubleType.DOUBLE, VARCHAR), new Object[] {5, 3.14159265358979, "Simple Record String Field"}); + assertBlockEquals(RowType.anonymousRow(INTEGER, DoubleType.DOUBLE, VARCHAR), p.getBlock(11), expected); + // test nullable union + assertThat(p.getBlock(12)).isInstanceOf(VariableWidthBlock.class); + assertThat(p.getBlock(12).isNull(0)).isTrue(); + } +} diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestLongFromBigEndian.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestLongFromBigEndian.java new file mode 100644 index 000000000000..6cddcf481ed0 --- /dev/null +++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/avro/TestLongFromBigEndian.java @@ -0,0 +1,140 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.hive.formats.avro; + +import com.google.common.primitives.Longs; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; + +import static com.google.common.base.Verify.verify; +import static io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager.fromBigEndian; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; + +public class TestLongFromBigEndian +{ + @Test + public void testArrays() + { + assertThat(fromBigEndian(new byte[] {(byte) 0xFF, (byte) 0xFF})).isEqualTo(-1); + assertThat(fromBigEndian(new byte[] {0, 0, 0, 0, 0, 0, (byte) 0xFF, (byte) 0xFF})).isEqualTo(65535); + assertThat(fromBigEndian(new byte[] {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0x80, 0, 0, 0, 0, 0, 0, 0})).isEqualTo(Long.MIN_VALUE); + } + + @Test + public void testIdentity() + { + long a = 780681600000L; + long b = Long.MIN_VALUE; + long c = Long.MAX_VALUE; + long d = 0L; + long e = -1L; + + assertThat(fromBigEndian(Longs.toByteArray(a))).isEqualTo(a); + assertThat(fromBigEndian(Longs.toByteArray(b))).isEqualTo(b); + assertThat(fromBigEndian(Longs.toByteArray(c))).isEqualTo(c); + assertThat(fromBigEndian(Longs.toByteArray(d))).isEqualTo(d); + assertThat(fromBigEndian(Longs.toByteArray(e))).isEqualTo(e); + } + + @Test + public void testLessThan8Bytes() + { + long a = 24L; + long b = -24L; + long c = 0L; + long d = 1L; + long e = -1L; + long f = 64L; + long g = -64L; + + for (int i = 0; i < 8; i++) { + assertThat(fromBigEndian(Arrays.copyOfRange(Longs.toByteArray(a), i, 8))).isEqualTo(a); + assertThat(fromBigEndian(Arrays.copyOfRange(Longs.toByteArray(b), i, 8))).isEqualTo(b); + assertThat(fromBigEndian(Arrays.copyOfRange(Longs.toByteArray(c), i, 8))).isEqualTo(c); + assertThat(fromBigEndian(Arrays.copyOfRange(Longs.toByteArray(d), i, 8))).isEqualTo(d); + 
assertThat(fromBigEndian(Arrays.copyOfRange(Longs.toByteArray(e), i, 8))).isEqualTo(e); + assertThat(fromBigEndian(Arrays.copyOfRange(Longs.toByteArray(f), i, 8))).isEqualTo(f); + assertThat(fromBigEndian(Arrays.copyOfRange(Longs.toByteArray(g), i, 8))).isEqualTo(g); + } + } + + public static byte[] padBigEndianCorrectly(long toPad, int totalSize) + { + verify(totalSize >= 8); + byte[] longBytes = Longs.toByteArray(toPad); + + byte[] padded = new byte[totalSize]; + + System.arraycopy(longBytes, 0, padded, totalSize - 8, 8); + if (toPad < 0) { + for (int i = 0; i < totalSize - 8; i++) { + padded[i] = -1; + } + } + return padded; + } + + @Test + public void testWithPadding() + { + long a = 780681600000L; + long b = Long.MIN_VALUE; + long c = Long.MAX_VALUE; + long d = 0L; + long e = -1L; + + for (int i = 9; i < 24; i++) { + assertThat(fromBigEndian(padBigEndianCorrectly(a, i))).isEqualTo(a); + assertThat(fromBigEndian(padBigEndianCorrectly(b, i))).isEqualTo(b); + assertThat(fromBigEndian(padBigEndianCorrectly(c, i))).isEqualTo(c); + assertThat(fromBigEndian(padBigEndianCorrectly(d, i))).isEqualTo(d); + assertThat(fromBigEndian(padBigEndianCorrectly(e, i))).isEqualTo(e); + } + } + + private static byte[] padPoorly(long toPad) + { + int totalSize = 32; + byte[] longBytes = Longs.toByteArray(toPad); + + byte[] padded = new byte[totalSize]; + + System.arraycopy(longBytes, 0, padded, totalSize - 8, 8); + + for (int i = 0; i < totalSize - 8; i++) { + padded[i] = ThreadLocalRandom.current().nextBoolean() ? 
(byte) ThreadLocalRandom.current().nextInt(1, Byte.MAX_VALUE) : (byte) ThreadLocalRandom.current().nextInt(Byte.MIN_VALUE, -1); + } + + return padded; + } + + @Test + public void testWithBadPadding() + { + long a = 780681600000L; + long b = Long.MIN_VALUE; + long c = Long.MAX_VALUE; + long d = 0L; + long e = -1L; + + assertThatThrownBy(() -> fromBigEndian(padPoorly(a))).isInstanceOf(ArithmeticException.class); + assertThatThrownBy(() -> fromBigEndian(padPoorly(b))).isInstanceOf(ArithmeticException.class); + assertThatThrownBy(() -> fromBigEndian(padPoorly(c))).isInstanceOf(ArithmeticException.class); + assertThatThrownBy(() -> fromBigEndian(padPoorly(d))).isInstanceOf(ArithmeticException.class); + assertThatThrownBy(() -> fromBigEndian(padPoorly(e))).isInstanceOf(ArithmeticException.class); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveFormatsConfig.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveFormatsConfig.java index 5a396802b633..5683d66f1d60 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveFormatsConfig.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveFormatsConfig.java @@ -18,6 +18,7 @@ public class HiveFormatsConfig { + private boolean avroFileNativeReaderEnabled = true; private boolean csvNativeReaderEnabled = true; private boolean csvNativeWriterEnabled = true; private boolean jsonNativeReaderEnabled = true; @@ -30,6 +31,19 @@ public class HiveFormatsConfig private boolean sequenceFileNativeReaderEnabled = true; private boolean sequenceFileNativeWriterEnabled = true; + public boolean isAvroFileNativeReaderEnabled() + { + return avroFileNativeReaderEnabled; + } + + @Config("avro.native-reader.enabled") + @ConfigDescription("Use native Avro file reader") + public HiveFormatsConfig setAvroFileNativeReaderEnabled(boolean avroFileNativeReaderEnabled) + { + this.avroFileNativeReaderEnabled = avroFileNativeReaderEnabled; + return this; + } + public boolean 
isCsvNativeReaderEnabled() { return csvNativeReaderEnabled; diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java index 0616acec9220..511f5d9dc365 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveModule.java @@ -25,6 +25,7 @@ import io.trino.hdfs.TrinoFileSystemCache; import io.trino.hdfs.TrinoFileSystemCacheStats; import io.trino.plugin.base.CatalogName; +import io.trino.plugin.hive.avro.AvroHivePageSourceFactory; import io.trino.plugin.hive.fs.CachingDirectoryLister; import io.trino.plugin.hive.fs.TransactionScopeCachingDirectoryListerFactory; import io.trino.plugin.hive.line.CsvFileWriterFactory; @@ -139,6 +140,7 @@ public void configure(Binder binder) pageSourceFactoryBinder.addBinding().to(OrcPageSourceFactory.class).in(Scopes.SINGLETON); pageSourceFactoryBinder.addBinding().to(ParquetPageSourceFactory.class).in(Scopes.SINGLETON); pageSourceFactoryBinder.addBinding().to(RcFilePageSourceFactory.class).in(Scopes.SINGLETON); + pageSourceFactoryBinder.addBinding().to(AvroHivePageSourceFactory.class).in(Scopes.SINGLETON); Multibinder recordCursorProviderBinder = newSetBinder(binder, HiveRecordCursorProvider.class); recordCursorProviderBinder.addBinding().to(S3SelectRecordCursorProvider.class).in(Scopes.SINGLETON); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java index eedd4f71a50b..f4aff5a01b86 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveSessionProperties.java @@ -65,6 +65,7 @@ public final class HiveSessionProperties private static final String PARALLEL_PARTITIONED_BUCKETED_WRITES = "parallel_partitioned_bucketed_writes"; private static 
final String FORCE_LOCAL_SCHEDULING = "force_local_scheduling"; private static final String INSERT_EXISTING_PARTITIONS_BEHAVIOR = "insert_existing_partitions_behavior"; + private static final String AVRO_NATIVE_READER_ENABLED = "avro_native_reader_enabled"; private static final String CSV_NATIVE_READER_ENABLED = "csv_native_reader_enabled"; private static final String CSV_NATIVE_WRITER_ENABLED = "csv_native_writer_enabled"; private static final String JSON_NATIVE_READER_ENABLED = "json_native_reader_enabled"; @@ -203,6 +204,11 @@ public HiveSessionProperties( false, value -> InsertExistingPartitionsBehavior.valueOf((String) value, hiveConfig.isImmutablePartitions()), InsertExistingPartitionsBehavior::toString), + booleanProperty( + AVRO_NATIVE_READER_ENABLED, + "Use native Avro file reader", + hiveFormatsConfig.isAvroFileNativeReaderEnabled(), + false), booleanProperty( CSV_NATIVE_READER_ENABLED, "Use native CSV reader", @@ -659,6 +665,11 @@ public static InsertExistingPartitionsBehavior getInsertExistingPartitionsBehavi return session.getProperty(INSERT_EXISTING_PARTITIONS_BEHAVIOR, InsertExistingPartitionsBehavior.class); } + public static boolean isAvroNativeReaderEnabled(ConnectorSession session) + { + return session.getProperty(AVRO_NATIVE_READER_ENABLED, Boolean.class); + } + public static boolean isCsvNativeReaderEnabled(ConnectorSession session) { return session.getProperty(CSV_NATIVE_READER_ENABLED, Boolean.class); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java index 691579009aca..0733174148db 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java @@ -35,15 +35,15 @@ import static com.google.common.base.Strings.lenientFormat; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.SizeOf.instanceSize; +import 
static io.trino.hive.formats.UnionToRowCoercionUtils.UNION_FIELD_FIELD_PREFIX; +import static io.trino.hive.formats.UnionToRowCoercionUtils.UNION_FIELD_TAG_NAME; +import static io.trino.hive.formats.UnionToRowCoercionUtils.UNION_FIELD_TAG_TYPE; import static io.trino.plugin.hive.HiveStorageFormat.AVRO; import static io.trino.plugin.hive.HiveStorageFormat.ORC; import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION; import static io.trino.plugin.hive.type.TypeInfoFactory.getPrimitiveTypeInfo; import static io.trino.plugin.hive.type.TypeInfoUtils.getTypeInfoFromTypeString; import static io.trino.plugin.hive.type.TypeInfoUtils.getTypeInfosFromTypeString; -import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_FIELD_PREFIX; -import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_NAME; -import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_TYPE; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeInfo; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeSignature; import static io.trino.plugin.hive.util.SerdeConstants.BIGINT_TYPE_NAME; @@ -249,7 +249,7 @@ public Optional getHiveTypeForDereferences(List dereferences) else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) { try { if (fieldIndex == 0) { - // union's tag field, defined in {@link io.trino.plugin.hive.util.HiveTypeTranslator#toTypeSignature} + // union's tag field, defined in {@link io.trino.hive.formats.UnionToRowCoercionUtils} return Optional.of(toHiveType(UNION_FIELD_TAG_TYPE)); } else { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveConstants.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveConstants.java new file mode 100644 index 000000000000..8bbf18ad7c33 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveConstants.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + 
* you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.avro; + +public final class AvroHiveConstants +{ + private AvroHiveConstants() {} + + // file metadata + public static final String WRITER_TIME_ZONE = "writer.time.zone"; + + //hive table properties + public static final String SCHEMA_LITERAL = "avro.schema.literal"; + public static final String SCHEMA_URL = "avro.schema.url"; + public static final String SCHEMA_NONE = "none"; + public static final String SCHEMA_NAMESPACE = "avro.schema.namespace"; + public static final String SCHEMA_NAME = "avro.schema.name"; + public static final String SCHEMA_DOC = "avro.schema.doc"; + public static final String TABLE_NAME = "name"; + + // Hive Logical types + public static final String CHAR_TYPE_LOGICAL_NAME = "char"; + public static final String VARCHAR_TYPE_LOGICAL_NAME = "varchar"; + public static final String VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP = "maxLength"; +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileUtils.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileUtils.java new file mode 100644 index 000000000000..a6ca72874e13 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHiveFileUtils.java @@ -0,0 +1,288 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.avro; + +import com.google.common.base.Splitter; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.filesystem.TrinoInputFile; +import io.trino.filesystem.TrinoInputStream; +import io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager; +import io.trino.plugin.hive.HiveType; +import io.trino.plugin.hive.type.CharTypeInfo; +import io.trino.plugin.hive.type.DecimalTypeInfo; +import io.trino.plugin.hive.type.ListTypeInfo; +import io.trino.plugin.hive.type.MapTypeInfo; +import io.trino.plugin.hive.type.PrimitiveCategory; +import io.trino.plugin.hive.type.PrimitiveTypeInfo; +import io.trino.plugin.hive.type.StructTypeInfo; +import io.trino.plugin.hive.type.TypeInfo; +import io.trino.plugin.hive.type.UnionTypeInfo; +import io.trino.plugin.hive.type.VarcharTypeInfo; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; + +import static io.trino.plugin.hive.HiveMetadata.TABLE_COMMENT; +import static io.trino.plugin.hive.avro.AvroHiveConstants.CHAR_TYPE_LOGICAL_NAME; +import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_DOC; +import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_LITERAL; +import static 
io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_NAME; +import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_NAMESPACE; +import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_NONE; +import static io.trino.plugin.hive.avro.AvroHiveConstants.SCHEMA_URL; +import static io.trino.plugin.hive.avro.AvroHiveConstants.TABLE_NAME; +import static io.trino.plugin.hive.avro.AvroHiveConstants.VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP; +import static io.trino.plugin.hive.avro.AvroHiveConstants.VARCHAR_TYPE_LOGICAL_NAME; +import static io.trino.plugin.hive.util.HiveUtil.getColumnNames; +import static io.trino.plugin.hive.util.HiveUtil.getColumnTypes; +import static io.trino.plugin.hive.util.SerdeConstants.LIST_COLUMN_COMMENTS; +import static java.util.Collections.emptyList; +import static java.util.function.Predicate.not; + +public final class AvroHiveFileUtils +{ + private final AtomicInteger recordNameSuffix = new AtomicInteger(0); + + private AvroHiveFileUtils() {} + + // Lifted and shifted from org.apache.hadoop.hive.serde2.avro.AvroSerdeUtils.determineSchemaOrThrowException + public static Schema determineSchemaOrThrowException(TrinoFileSystem fileSystem, Properties properties) + throws IOException + { + // Try pull schema from literal table property + String schemaString = properties.getProperty(SCHEMA_LITERAL, ""); + if (!schemaString.isBlank() && !schemaString.equals(SCHEMA_NONE)) { + return getSchemaParser().parse(schemaString); + } + + // Try pull schema directly from URL + String schemaURL = properties.getProperty(SCHEMA_URL, ""); + if (!schemaURL.isBlank()) { + TrinoInputFile schemaFile = fileSystem.newInputFile(Location.of(schemaURL)); + if (!schemaFile.exists()) { + throw new IOException("No avro schema file found at " + schemaURL); + } + try (TrinoInputStream inputStream = schemaFile.newStream()) { + return getSchemaParser().parse(inputStream); + } + catch (IOException e) { + throw new IOException("Unable to read avro schema file from 
given path: " + schemaURL, e); + } + } + Schema schema = getSchemaFromProperties(properties); + properties.setProperty(SCHEMA_LITERAL, schema.toString()); + return schema; + } + + private static Schema getSchemaFromProperties(Properties properties) + throws IOException + { + List columnNames = getColumnNames(properties); + List columnTypes = getColumnTypes(properties); + if (columnNames.isEmpty() || columnTypes.isEmpty()) { + throw new IOException("Unable to parse column names or column types from job properties to create Avro Schema"); + } + if (columnNames.size() != columnTypes.size()) { + throw new IllegalArgumentException("Avro Schema initialization failed. Number of column name and column type differs. columnNames = %s, columnTypes = %s".formatted(columnNames, columnTypes)); + } + List columnComments = Optional.ofNullable(properties.getProperty(LIST_COLUMN_COMMENTS)) + .filter(not(String::isBlank)) + .map(Splitter.on('\0')::splitToList) + .orElse(emptyList()); + + final String tableName = properties.getProperty(TABLE_NAME); + final String tableComment = properties.getProperty(TABLE_COMMENT); + + return constructSchemaFromParts( + columnNames, + columnTypes, + columnComments, + Optional.ofNullable(properties.getProperty(SCHEMA_NAMESPACE)), + Optional.ofNullable(properties.getProperty(SCHEMA_NAME, tableName)), + Optional.ofNullable(properties.getProperty(SCHEMA_DOC, tableComment))); + } + + private static Schema constructSchemaFromParts(List columnNames, List columnTypes, + List columnComments, Optional namespace, Optional name, Optional doc) + { + // create instance of this class to keep nested record naming consistent for any given inputs + AvroHiveFileUtils recordIncrementingUtil = new AvroHiveFileUtils(); + SchemaBuilder.RecordBuilder schemaBuilder = SchemaBuilder.record(name.orElse("baseRecord")); + namespace.ifPresent(schemaBuilder::namespace); + doc.ifPresent(schemaBuilder::doc); + SchemaBuilder.FieldAssembler fieldBuilder = schemaBuilder.fields(); + + 
for (int i = 0; i < columnNames.size(); ++i) { + String comment = columnComments.size() > i ? columnComments.get(i) : null; + Schema fieldSchema = recordIncrementingUtil.avroSchemaForHiveType(columnTypes.get(i)); + fieldBuilder = fieldBuilder + .name(columnNames.get(i)) + .doc(comment) + .type(fieldSchema) + .withDefault(null); + } + return fieldBuilder.endRecord(); + } + + private Schema avroSchemaForHiveType(HiveType hiveType) + { + Schema schema = switch (hiveType.getCategory()) { + case PRIMITIVE -> createAvroPrimitive(hiveType); + case LIST -> { + ListTypeInfo listTypeInfo = (ListTypeInfo) hiveType.getTypeInfo(); + yield Schema.createArray(avroSchemaForHiveType(HiveType.toHiveType(listTypeInfo.getListElementTypeInfo()))); + } + case MAP -> { + MapTypeInfo mapTypeInfo = ((MapTypeInfo) hiveType.getTypeInfo()); + TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + if (!(keyTypeInfo instanceof PrimitiveTypeInfo primitiveKeyTypeInfo) || + primitiveKeyTypeInfo.getPrimitiveCategory() != PrimitiveCategory.STRING) { + throw new UnsupportedOperationException("Key of Map must be a String"); + } + TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + yield Schema.createMap(avroSchemaForHiveType(HiveType.toHiveType(valueTypeInfo))); + } + case STRUCT -> createAvroRecord(hiveType); + case UNION -> { + List childSchemas = new ArrayList<>(); + for (TypeInfo childTypeInfo : ((UnionTypeInfo) hiveType.getTypeInfo()).getAllUnionObjectTypeInfos()) { + final Schema childSchema = avroSchemaForHiveType(HiveType.toHiveType(childTypeInfo)); + if (childSchema.getType() == Schema.Type.UNION) { + childSchemas.addAll(childSchema.getTypes()); + } + else { + childSchemas.add(childSchema); + } + } + yield Schema.createUnion(removeDuplicateNullSchemas(childSchemas)); + } + }; + + return wrapInUnionWithNull(schema); + } + + private static Schema createAvroPrimitive(HiveType hiveType) + { + if (!(hiveType.getTypeInfo() instanceof PrimitiveTypeInfo primitiveTypeInfo)) { + throw new 
IllegalStateException("HiveType in primitive category must have PrimitiveTypeInfo"); + } + return switch (primitiveTypeInfo.getPrimitiveCategory()) { + case STRING -> Schema.create(Schema.Type.STRING); + case CHAR -> { + Schema charSchema = SchemaBuilder.builder().type(Schema.create(Schema.Type.STRING)); + charSchema.addProp(LogicalType.LOGICAL_TYPE_PROP, CHAR_TYPE_LOGICAL_NAME); + charSchema.addProp(VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP, ((CharTypeInfo) hiveType.getTypeInfo()).getLength()); + yield charSchema; + } + case VARCHAR -> { + Schema varcharSchema = SchemaBuilder.builder().type(Schema.create(Schema.Type.STRING)); + varcharSchema.addProp(LogicalType.LOGICAL_TYPE_PROP, VARCHAR_TYPE_LOGICAL_NAME); + varcharSchema.addProp(VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP, ((VarcharTypeInfo) hiveType.getTypeInfo()).getLength()); + yield varcharSchema; + } + case BINARY -> Schema.create(Schema.Type.BYTES); + case BYTE, SHORT, INT -> Schema.create(Schema.Type.INT); + case LONG -> Schema.create(Schema.Type.LONG); + case FLOAT -> Schema.create(Schema.Type.FLOAT); + case DOUBLE -> Schema.create(Schema.Type.DOUBLE); + case BOOLEAN -> Schema.create(Schema.Type.BOOLEAN); + case DECIMAL -> { + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) hiveType.getTypeInfo(); + LogicalTypes.Decimal decimalLogicalType = LogicalTypes.decimal(decimalTypeInfo.precision(), decimalTypeInfo.scale()); + yield decimalLogicalType.addToSchema(Schema.create(Schema.Type.BYTES)); + } + case DATE -> NativeLogicalTypesAvroTypeManager.DATE_SCHEMA; + case TIMESTAMP -> NativeLogicalTypesAvroTypeManager.TIMESTAMP_MILLIS_SCHEMA; + case VOID -> Schema.create(Schema.Type.NULL); + default -> throw new UnsupportedOperationException(hiveType + " is not supported."); + }; + } + + private Schema createAvroRecord(HiveType hiveType) + { + if (!(hiveType.getTypeInfo() instanceof StructTypeInfo structTypeInfo)) { + throw new IllegalStateException("HiveType type info must be Struct Type info to make Avro 
Record"); + } + + final List allStructFieldNames = + structTypeInfo.getAllStructFieldNames(); + final List allStructFieldTypeInfo = + structTypeInfo.getAllStructFieldTypeInfos(); + if (allStructFieldNames.size() != allStructFieldTypeInfo.size()) { + throw new IllegalArgumentException("Failed to generate avro schema from hive schema. " + + "name and column type differs. names = " + allStructFieldNames + ", types = " + + allStructFieldTypeInfo); + } + + SchemaBuilder.FieldAssembler fieldAssembler = SchemaBuilder + .record("record_" + recordNameSuffix.getAndIncrement()) + .doc(structTypeInfo.toString()) + .fields(); + + for (int i = 0; i < allStructFieldNames.size(); ++i) { + final TypeInfo childTypeInfo = allStructFieldTypeInfo.get(i); + final Schema fieldSchema = avroSchemaForHiveType(HiveType.toHiveType(childTypeInfo)); + fieldAssembler = fieldAssembler + .name(allStructFieldNames.get(i)) + .doc(childTypeInfo.toString()) + .type(fieldSchema) + .withDefault(null); + } + return fieldAssembler.endRecord(); + } + + public static Schema wrapInUnionWithNull(Schema schema) + { + return switch (schema.getType()) { + case NULL -> schema; + case UNION -> Schema.createUnion(removeDuplicateNullSchemas(schema.getTypes())); + default -> Schema.createUnion(Arrays.asList(Schema.create(Schema.Type.NULL), schema)); + }; + } + + private static List removeDuplicateNullSchemas(List childSchemas) + { + List prunedSchemas = new ArrayList<>(); + boolean isNullPresent = false; + for (Schema schema : childSchemas) { + if (schema.getType() == Schema.Type.NULL) { + isNullPresent = true; + } + else { + prunedSchemas.add(schema); + } + } + if (isNullPresent) { + prunedSchemas.add(0, Schema.create(Schema.Type.NULL)); + } + + return prunedSchemas; + } + + private static Schema.Parser getSchemaParser() + { + // HIVE-24797: Disable validate default values when parsing Avro schemas. 
+ return new Schema.Parser().setValidateDefaults(false); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHivePageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHivePageSource.java new file mode 100644 index 000000000000..1b03c8cb2ced --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHivePageSource.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.avro; + +import io.airlift.units.DataSize; +import io.trino.filesystem.TrinoInputFile; +import io.trino.hive.formats.avro.AvroFileReader; +import io.trino.hive.formats.avro.AvroTypeException; +import io.trino.hive.formats.avro.AvroTypeManager; +import io.trino.spi.Page; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorPageSource; +import org.apache.avro.Schema; + +import java.io.IOException; +import java.util.OptionalLong; + +import static io.trino.plugin.base.util.Closables.closeAllSuppress; +import static io.trino.plugin.hive.HiveErrorCode.HIVE_CURSOR_ERROR; +import static java.util.Objects.requireNonNull; + +public class AvroHivePageSource + implements ConnectorPageSource +{ + private static final long GUESSED_MEMORY_USAGE = DataSize.of(16, DataSize.Unit.MEGABYTE).toBytes(); + + private final String fileName; + private final AvroFileReader avroFileReader; + + public AvroHivePageSource( + TrinoInputFile inputFile, + Schema schema, + 
AvroTypeManager avroTypeManager, + long offset, + long length) + throws IOException, AvroTypeException + { + fileName = requireNonNull(inputFile, "inputFile is null").location().fileName(); + avroFileReader = new AvroFileReader(inputFile, schema, avroTypeManager, offset, OptionalLong.of(length)); + } + + @Override + public long getCompletedBytes() + { + return avroFileReader.getCompletedBytes(); + } + + @Override + public long getReadTimeNanos() + { + return avroFileReader.getReadTimeNanos(); + } + + @Override + public boolean isFinished() + { + try { + return !avroFileReader.hasNext(); + } + catch (IOException | RuntimeException e) { + closeAllSuppress(e, this); + throw new TrinoException(HIVE_CURSOR_ERROR, "Failed to read Avro file: " + fileName, e); + } + } + + @Override + public Page getNextPage() + { + try { + if (avroFileReader.hasNext()) { + return avroFileReader.next(); + } + else { + return null; + } + } + catch (IOException | RuntimeException e) { + closeAllSuppress(e, this); + throw new TrinoException(HIVE_CURSOR_ERROR, "Failed to read Avro file: " + fileName, e); + } + } + + @Override + public long getMemoryUsage() + { + return GUESSED_MEMORY_USAGE; + } + + @Override + public void close() + throws IOException + { + avroFileReader.close(); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHivePageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHivePageSourceFactory.java new file mode 100644 index 000000000000..c2a9ba0887bb --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroHivePageSourceFactory.java @@ -0,0 +1,248 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.avro; + +import com.google.inject.Inject; +import io.airlift.slice.Slices; +import io.airlift.units.DataSize; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.filesystem.TrinoFileSystemFactory; +import io.trino.filesystem.TrinoInputFile; +import io.trino.filesystem.TrinoInputStream; +import io.trino.filesystem.memory.MemoryInputFile; +import io.trino.hive.formats.avro.AvroTypeException; +import io.trino.plugin.hive.AcidInfo; +import io.trino.plugin.hive.FileFormatDataSourceStats; +import io.trino.plugin.hive.HiveColumnHandle; +import io.trino.plugin.hive.HivePageSourceFactory; +import io.trino.plugin.hive.HiveTimestampPrecision; +import io.trino.plugin.hive.ReaderColumns; +import io.trino.plugin.hive.ReaderPageSource; +import io.trino.plugin.hive.acid.AcidTransaction; +import io.trino.plugin.hive.fs.MonitoredInputFile; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.EmptyPageSource; +import io.trino.spi.predicate.TupleDomain; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.util.internal.Accessor; +import org.apache.hadoop.conf.Configuration; + +import java.io.IOException; +import java.util.AbstractCollection; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Properties; +import java.util.Set; +import 
java.util.UUID; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; +import static io.trino.plugin.hive.HivePageSourceProvider.projectBaseColumns; +import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; +import static io.trino.plugin.hive.HiveSessionProperties.isAvroNativeReaderEnabled; +import static io.trino.plugin.hive.ReaderPageSource.noProjectionAdaptation; +import static io.trino.plugin.hive.avro.AvroHiveFileUtils.wrapInUnionWithNull; +import static io.trino.plugin.hive.util.HiveClassNames.AVRO_SERDE_CLASS; +import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName; +import static io.trino.plugin.hive.util.HiveUtil.splitError; +import static java.lang.Math.min; +import static java.util.Objects.requireNonNull; + +public class AvroHivePageSourceFactory + implements HivePageSourceFactory +{ + private static final DataSize BUFFER_SIZE = DataSize.of(8, DataSize.Unit.MEGABYTE); + + private final TrinoFileSystemFactory trinoFileSystemFactory; + private final FileFormatDataSourceStats stats; + + @Inject + public AvroHivePageSourceFactory(TrinoFileSystemFactory trinoFileSystemFactory, FileFormatDataSourceStats stats) + { + this.trinoFileSystemFactory = requireNonNull(trinoFileSystemFactory, "trinoFileSystemFactory is null"); + this.stats = requireNonNull(stats, "stats is null"); + } + + @Override + public Optional createPageSource( + Configuration configuration, + ConnectorSession session, + Location path, + long start, + long length, + long estimatedFileSize, + Properties schema, + List columns, + TupleDomain effectivePredicate, + Optional acidInfo, + OptionalInt bucketNumber, + boolean originalFile, + AcidTransaction transaction) + { + if (!isAvroNativeReaderEnabled(session)) { + return Optional.empty(); + } + 
else if (!AVRO_SERDE_CLASS.equals(getDeserializerClassName(schema))) { + return Optional.empty(); + } + checkArgument(acidInfo.isEmpty(), "Acid is not supported"); + + List projectedReaderColumns = columns; + Optional readerProjections = projectBaseColumns(columns); + + if (readerProjections.isPresent()) { + projectedReaderColumns = readerProjections.get().get().stream() + .map(HiveColumnHandle.class::cast) + .collect(toImmutableList()); + } + + TrinoFileSystem trinoFileSystem = trinoFileSystemFactory.create(session.getIdentity()); + TrinoInputFile inputFile = new MonitoredInputFile(stats, trinoFileSystem.newInputFile(path)); + + Schema tableSchema; + try { + tableSchema = AvroHiveFileUtils.determineSchemaOrThrowException(trinoFileSystem, schema); + } + catch (IOException | org.apache.avro.AvroTypeException e) { + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Unable to load or parse schema", e); + } + + try { + length = min(inputFile.length() - start, length); + if (estimatedFileSize < BUFFER_SIZE.toBytes()) { + try (TrinoInputStream input = inputFile.newStream()) { + byte[] data = input.readAllBytes(); + inputFile = new MemoryInputFile(path, Slices.wrappedBuffer(data)); + } + } + } + catch (TrinoException e) { + throw e; + } + catch (Exception e) { + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, splitError(e, path, start, length), e); + } + + // Split may be empty now that the correct file size is known + if (length <= 0) { + return Optional.of(noProjectionAdaptation(new EmptyPageSource())); + } + + Schema maskedSchema; + try { + maskedSchema = maskColumnsFromTableSchema(projectedReaderColumns, tableSchema); + } + catch (org.apache.avro.AvroTypeException e) { + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Avro type resolution error when initializing split from %s".formatted(path), e); + } + + HiveTimestampPrecision hiveTimestampPrecision = getTimestampPrecision(session); + if (maskedSchema.getFields().isEmpty()) { + // no non-masked columns to select 
from partition schema + // hack to return null rows with same total count as underlying data file + // will error if UUID is same name as base column for underlying storage table but should never + // return false data. If file data has f+uuid column in schema then resolution of read null from not null will fail. + SchemaBuilder.FieldAssembler nullSchema = SchemaBuilder.record("null_only").fields(); + for (int i = 0; i < Math.max(projectedReaderColumns.size(), 1); i++) { + String notAColumnName = null; + while (Objects.isNull(notAColumnName) || Objects.nonNull(tableSchema.getField(notAColumnName))) { + notAColumnName = "f" + UUID.randomUUID().toString().replace('-', '_'); + } + nullSchema = nullSchema.name(notAColumnName).type(Schema.create(Schema.Type.NULL)).withDefault(null); + } + try { + return Optional.of(noProjectionAdaptation(new AvroHivePageSource(inputFile, nullSchema.endRecord(), new HiveAvroTypeManager(hiveTimestampPrecision), start, length))); + } + catch (IOException e) { + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, e); + } + catch (AvroTypeException e) { + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Avro type resolution error when initializing split from %s".formatted(path), e); + } + } + + try { + return Optional.of(new ReaderPageSource(new AvroHivePageSource(inputFile, maskedSchema, new HiveAvroTypeManager(hiveTimestampPrecision), start, length), readerProjections)); + } + catch (IOException e) { + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, e); + } + catch (AvroTypeException e) { + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Avro type resolution error when initializing split from %s".formatted(path), e); + } + } + + private Schema maskColumnsFromTableSchema(List columns, Schema tableSchema) + { + verify(tableSchema.getType() == Schema.Type.RECORD); + Set maskedColumns = columns.stream().map(HiveColumnHandle::getBaseColumnName).collect(LinkedHashSet::new, HashSet::add, AbstractCollection::addAll); + + 
SchemaBuilder.FieldAssembler maskedSchema = SchemaBuilder.builder() + .record(tableSchema.getName()) + .namespace(tableSchema.getNamespace()) + .fields(); + + for (String columnName : maskedColumns) { + Schema.Field field = tableSchema.getField(columnName); + if (Objects.isNull(field)) { + continue; + } + if (field.hasDefaultValue()) { + try { + Object defaultObj = Accessor.defaultValue(field); + maskedSchema = maskedSchema + .name(field.name()) + .aliases(field.aliases().toArray(String[]::new)) + .doc(field.doc()) + .type(field.schema()) + .withDefault(defaultObj); + } + catch (org.apache.avro.AvroTypeException e) { + // in order to maintain backwards compatibility invalid defaults are mapped to null + // behavior defined by io.trino.tests.product.hive.TestAvroSchemaStrictness.testInvalidUnionDefaults + // solution is to make the field nullable and default-able to null. Any place default would be used, null will be + if (e.getMessage().contains("Invalid default")) { + maskedSchema = maskedSchema + .name(field.name()) + .aliases(field.aliases().toArray(String[]::new)) + .doc(field.doc()) + .type(wrapInUnionWithNull(field.schema())) + .withDefault(null); + } + else { + throw e; + } + } + } + else { + maskedSchema = maskedSchema + .name(field.name()) + .aliases(field.aliases().toArray(String[]::new)) + .doc(field.doc()) + .type(field.schema()) + .noDefault(); + } + } + return maskedSchema.endRecord(); + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/HiveAvroTypeManager.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/HiveAvroTypeManager.java new file mode 100644 index 000000000000..e79bab784891 --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/HiveAvroTypeManager.java @@ -0,0 +1,206 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.avro; + +import io.airlift.slice.Slices; +import io.trino.hive.formats.avro.AvroTypeException; +import io.trino.hive.formats.avro.NativeLogicalTypesAvroTypeManager; +import io.trino.plugin.hive.HiveTimestampPrecision; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.Chars; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.Timestamps; +import io.trino.spi.type.Type; +import io.trino.spi.type.Varchars; +import org.apache.avro.Schema; +import org.joda.time.DateTimeZone; + +import java.nio.charset.StandardCharsets; +import java.time.ZoneId; +import java.util.Map; +import java.util.Optional; +import java.util.TimeZone; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; + +import static io.trino.plugin.hive.avro.AvroHiveConstants.CHAR_TYPE_LOGICAL_NAME; +import static io.trino.plugin.hive.avro.AvroHiveConstants.VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP; +import static io.trino.plugin.hive.avro.AvroHiveConstants.VARCHAR_TYPE_LOGICAL_NAME; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static java.time.ZoneOffset.UTC; +import static java.util.Objects.requireNonNull; + +public class HiveAvroTypeManager + extends NativeLogicalTypesAvroTypeManager +{ + private final AtomicReference convertToTimezone = new 
AtomicReference<>(UTC); + private final TimestampType upcastMillisToType; + + public HiveAvroTypeManager(HiveTimestampPrecision hiveTimestampPrecision) + { + upcastMillisToType = createTimestampType(requireNonNull(hiveTimestampPrecision, "hiveTimestampPrecision is null").getPrecision()); + } + + @Override + public void configure(Map fileMetadata) + { + if (fileMetadata.containsKey(AvroHiveConstants.WRITER_TIME_ZONE)) { + convertToTimezone.set(ZoneId.of(new String(fileMetadata.get(AvroHiveConstants.WRITER_TIME_ZONE), StandardCharsets.UTF_8))); + } + else { + // legacy path allows this conversion to be skipped with {@link org.apache.hadoop.conf.Configuration} param + // currently no way to set that configuration in Trino + convertToTimezone.set(TimeZone.getDefault().toZoneId()); + } + } + + @Override + public Optional overrideTypeForSchema(Schema schema) + throws AvroTypeException + { + if (schema.getType() == Schema.Type.NULL) { + // allows of dereference when no base columns from file used + // BooleanType chosen rather arbitrarily to be stuffed with null + // in response to behavior defined by io.trino.tests.product.hive.TestAvroSchemaStrictness.testInvalidUnionDefaults + return Optional.of(BooleanType.BOOLEAN); + } + ValidateLogicalTypeResult result = validateLogicalType(schema); + // mapped in from HiveType translator + // TODO replace with sealed class case match syntax when stable + if (result instanceof NativeLogicalTypesAvroTypeManager.NoLogicalType ignored) { + return Optional.empty(); + } + if (result instanceof NonNativeAvroLogicalType nonNativeAvroLogicalType) { + return switch (nonNativeAvroLogicalType.getLogicalTypeName()) { + case VARCHAR_TYPE_LOGICAL_NAME, CHAR_TYPE_LOGICAL_NAME -> Optional.of(getHiveLogicalVarCharOrCharType(schema, nonNativeAvroLogicalType)); + default -> Optional.empty(); + }; + } + if (result instanceof NativeLogicalTypesAvroTypeManager.InvalidNativeAvroLogicalType invalidNativeAvroLogicalType) { + return switch 
(invalidNativeAvroLogicalType.getLogicalTypeName()) { + case TIMESTAMP_MILLIS, DATE, DECIMAL -> throw invalidNativeAvroLogicalType.getCause(); + default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types + }; + } + if (result instanceof NativeLogicalTypesAvroTypeManager.ValidNativeAvroLogicalType validNativeAvroLogicalType) { + return switch (validNativeAvroLogicalType.getLogicalType().getName()) { + case DATE -> super.overrideTypeForSchema(schema); + case TIMESTAMP_MILLIS -> Optional.of(upcastMillisToType); + case DECIMAL -> { + if (schema.getType() == Schema.Type.FIXED) { + // for backwards compatibility + throw new AvroTypeException("Hive does not support fixed decimal types"); + } + yield super.overrideTypeForSchema(schema); + } + default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types + }; + } + throw new IllegalStateException("Unhandled validate logical type result"); + } + + @Override + public Optional> overrideBuildingFunctionForSchema(Schema schema) + throws AvroTypeException + { + ValidateLogicalTypeResult result = validateLogicalType(schema); + // TODO replace with sealed class case match syntax when stable + if (result instanceof NativeLogicalTypesAvroTypeManager.NoLogicalType ignored) { + return Optional.empty(); + } + if (result instanceof NonNativeAvroLogicalType nonNativeAvroLogicalType) { + return switch (nonNativeAvroLogicalType.getLogicalTypeName()) { + case VARCHAR_TYPE_LOGICAL_NAME, CHAR_TYPE_LOGICAL_NAME -> { + Type type = getHiveLogicalVarCharOrCharType(schema, nonNativeAvroLogicalType); + if (nonNativeAvroLogicalType.getLogicalTypeName().equals(VARCHAR_TYPE_LOGICAL_NAME)) { + yield Optional.of(((blockBuilder, obj) -> { + type.writeSlice(blockBuilder, Varchars.truncateToLength(Slices.utf8Slice(obj.toString()), type)); + })); + } + else { + yield Optional.of(((blockBuilder, obj) -> { + type.writeSlice(blockBuilder, 
Chars.truncateToLengthAndTrimSpaces(Slices.utf8Slice(obj.toString()), type)); + })); + } + } + default -> Optional.empty(); + }; + } + if (result instanceof NativeLogicalTypesAvroTypeManager.InvalidNativeAvroLogicalType invalidNativeAvroLogicalType) { + return switch (invalidNativeAvroLogicalType.getLogicalTypeName()) { + case TIMESTAMP_MILLIS, DATE, DECIMAL -> throw invalidNativeAvroLogicalType.getCause(); + default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types + }; + } + if (result instanceof NativeLogicalTypesAvroTypeManager.ValidNativeAvroLogicalType validNativeAvroLogicalType) { + return switch (validNativeAvroLogicalType.getLogicalType().getName()) { + case TIMESTAMP_MILLIS -> { + if (upcastMillisToType.isShort()) { + yield Optional.of((blockBuilder, obj) -> { + Long millisSinceEpochUTC = (Long) obj; + upcastMillisToType.writeLong(blockBuilder, DateTimeZone.forTimeZone(TimeZone.getTimeZone(convertToTimezone.get())).convertUTCToLocal(millisSinceEpochUTC) * Timestamps.MICROSECONDS_PER_MILLISECOND); + }); + } + else { + yield Optional.of((blockBuilder, obj) -> { + Long millisSinceEpochUTC = (Long) obj; + LongTimestamp longTimestamp = new LongTimestamp(DateTimeZone.forTimeZone(TimeZone.getTimeZone(convertToTimezone.get())).convertUTCToLocal(millisSinceEpochUTC) * Timestamps.MICROSECONDS_PER_MILLISECOND, 0); + upcastMillisToType.writeObject(blockBuilder, longTimestamp); + }); + } + } + case DATE, DECIMAL -> super.overrideBuildingFunctionForSchema(schema); + default -> Optional.empty(); // Other logical types ignored by hive/don't map to hive types + }; + } + throw new IllegalStateException("Unhandled validate logical type result"); + } + + private static Type getHiveLogicalVarCharOrCharType(Schema schema, NonNativeAvroLogicalType nonNativeAvroLogicalType) + throws AvroTypeException + { + if (schema.getType() != Schema.Type.STRING) { + throw new AvroTypeException("Unsupported Avro type for Hive Logical Type in schema " + 
schema); + } + Object maxLengthObject = schema.getObjectProp(VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP); + if (maxLengthObject == null) { + throw new AvroTypeException("Missing property maxLength in schema for Hive Type " + nonNativeAvroLogicalType.getLogicalTypeName()); + } + try { + int maxLength; + if (maxLengthObject instanceof String maxLengthString) { + maxLength = Integer.parseInt(maxLengthString); + } + else if (maxLengthObject instanceof Number maxLengthNumber) { + maxLength = maxLengthNumber.intValue(); + } + else { + throw new AvroTypeException("Unrecognized property type for " + VARCHAR_AND_CHAR_LOGICAL_TYPE_LENGTH_PROP + " in schema " + schema); + } + if (nonNativeAvroLogicalType.getLogicalTypeName().equals(VARCHAR_TYPE_LOGICAL_NAME)) { + return createVarcharType(maxLength); + } + else { + return createCharType(maxLength); + } + } + catch (NumberFormatException numberFormatException) { + throw new AvroTypeException("Property maxLength not convertible to Integer in Hive Logical type schema " + schema); + } + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSourceFactory.java index f59bd348a4ac..39b6017a9e9e 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSourceFactory.java @@ -58,8 +58,8 @@ import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName; import static io.trino.plugin.hive.util.HiveUtil.getFooterCount; import static io.trino.plugin.hive.util.HiveUtil.getHeaderCount; +import static io.trino.plugin.hive.util.HiveUtil.splitError; import static java.lang.Math.min; -import static java.lang.String.format; import static java.util.Objects.requireNonNull; public abstract class LinePageSourceFactory @@ -179,9 +179,4 @@ public Optional createPageSource( throw new 
TrinoException(HIVE_CANNOT_OPEN_SPLIT, splitError(e, path, start, length), e); } } - - private static String splitError(Throwable t, Location path, long start, long length) - { - return format("Error opening Hive split %s (offset=%s, length=%s): %s", path, start, length, t.getMessage()); - } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java index 30b45d601cec..b6007ec28c6b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java @@ -100,6 +100,7 @@ import static io.trino.plugin.hive.util.AcidTables.isFullAcidTable; import static io.trino.plugin.hive.util.HiveClassNames.ORC_SERDE_CLASS; import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName; +import static io.trino.plugin.hive.util.HiveUtil.splitError; import static io.trino.plugin.hive.util.SerdeConstants.SERIALIZATION_LIB; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; @@ -540,11 +541,6 @@ private static boolean hasOriginalFiles(AcidInfo acidInfo) return !acidInfo.getOriginalFiles().isEmpty(); } - private static String splitError(Throwable t, Location path, long start, long length) - { - return format("Error opening Hive split %s (offset=%s, length=%s): %s", path, start, length, t.getMessage()); - } - private static void verifyFileHasColumnNames(List columns, Location path) { if (!columns.isEmpty() && columns.stream().map(OrcColumn::getColumnName).allMatch(physicalColumnName -> DEFAULT_HIVE_COLUMN_NAME_PATTERN.matcher(physicalColumnName).matches())) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSourceFactory.java index 4a93461ca8f8..766f0eeb8c1c 
100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSourceFactory.java @@ -66,9 +66,9 @@ import static io.trino.plugin.hive.util.HiveClassNames.COLUMNAR_SERDE_CLASS; import static io.trino.plugin.hive.util.HiveClassNames.LAZY_BINARY_COLUMNAR_SERDE_CLASS; import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName; +import static io.trino.plugin.hive.util.HiveUtil.splitError; import static io.trino.plugin.hive.util.SerdeConstants.SERIALIZATION_LIB; import static java.lang.Math.min; -import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class RcFilePageSourceFactory @@ -193,9 +193,4 @@ else if (deserializerClassName.equals(COLUMNAR_SERDE_CLASS)) { throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, message, e); } } - - private static String splitError(Throwable t, Location path, long start, long length) - { - return format("Error opening Hive split %s (offset=%s, length=%s): %s", path, start, length, t.getMessage()); - } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java index 1b44a0c8d466..4cb1e870632e 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java @@ -45,6 +45,7 @@ import java.util.stream.Collectors; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.hive.formats.UnionToRowCoercionUtils.rowTypeSignatureForUnionOfTypes; import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION; import static io.trino.plugin.hive.HiveType.HIVE_BINARY; import static io.trino.plugin.hive.HiveType.HIVE_BOOLEAN; @@ -79,7 +80,6 @@ import static io.trino.spi.type.TypeSignature.arrayType; import 
static io.trino.spi.type.TypeSignature.mapType; import static io.trino.spi.type.TypeSignature.rowType; -import static io.trino.spi.type.TypeSignatureParameter.namedField; import static io.trino.spi.type.TypeSignatureParameter.typeParameter; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; @@ -91,10 +91,6 @@ public final class HiveTypeTranslator { private HiveTypeTranslator() {} - public static final String UNION_FIELD_TAG_NAME = "tag"; - public static final String UNION_FIELD_FIELD_PREFIX = "field"; - public static final Type UNION_FIELD_TAG_TYPE = TINYINT; - public static TypeInfo toTypeInfo(Type type) { requireNonNull(type, "type is null"); @@ -212,13 +208,9 @@ public static TypeSignature toTypeSignature(TypeInfo typeInfo, HiveTimestampPrec // Use a row type to represent a union type in Hive for reading UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; List unionObjectTypes = unionTypeInfo.getAllUnionObjectTypeInfos(); - ImmutableList.Builder typeSignatures = ImmutableList.builder(); - typeSignatures.add(namedField(UNION_FIELD_TAG_NAME, UNION_FIELD_TAG_TYPE.getTypeSignature())); - for (int i = 0; i < unionObjectTypes.size(); i++) { - TypeInfo unionObjectType = unionObjectTypes.get(i); - typeSignatures.add(namedField(UNION_FIELD_FIELD_PREFIX + i, toTypeSignature(unionObjectType, timestampPrecision))); - } - return rowType(typeSignatures.build()); + return rowTypeSignatureForUnionOfTypes(unionObjectTypes.stream() + .map(unionObjectType -> toTypeSignature(unionObjectType, timestampPrecision)) + .collect(toImmutableList())); } throw new TrinoException(NOT_SUPPORTED, format("Unsupported Hive type: %s", typeInfo)); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveUtil.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveUtil.java index bf46770d37f2..f259387c9600 100644 --- 
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveUtil.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveUtil.java @@ -25,6 +25,7 @@ import io.airlift.compress.lzo.LzopCodec; import io.airlift.slice.Slice; import io.airlift.slice.SliceUtf8; +import io.trino.filesystem.Location; import io.trino.hadoop.TextLineLengthLimitExceededException; import io.trino.hive.formats.compression.CompressionKind; import io.trino.orc.OrcWriterOptions; @@ -208,6 +209,11 @@ public final class HiveUtil private static final CharMatcher DOT_MATCHER = CharMatcher.is('.'); + public static String splitError(Throwable t, Location location, long start, long length) + { + return format("Error opening Hive split %s (offset=%s, length=%s): %s", location, start, length, t.getMessage()); + } + static { DateTimeParser[] timestampWithoutTimeZoneParser = { DateTimeFormat.forPattern("yyyy-M-d").getParser(), diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SerdeConstants.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SerdeConstants.java index 5967fe0c91ae..0193ed8e9fcd 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SerdeConstants.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/SerdeConstants.java @@ -33,6 +33,7 @@ public final class SerdeConstants public static final String LIST_COLUMNS = "columns"; public static final String LIST_COLUMN_TYPES = "columns.types"; + public static final String LIST_COLUMN_COMMENTS = "columns.comments"; public static final String COLUMN_NAME_DELIMITER = "column.name.delimiter"; diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index f8729f5d16ac..355e40b3ecc1 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ 
b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -5086,6 +5086,30 @@ public void testAvroTypeValidation() assertQueryFails("CREATE TABLE test_avro_types WITH (format = 'AVRO') AS SELECT cast(42 AS smallint) z", "Column 'z' is smallint, which is not supported by Avro. Use integer instead."); } + @Test + public void testAvroTimestampUpCasting() + { + @Language("SQL") String createTable = "CREATE TABLE test_avro_timestamp_upcasting WITH (format = 'AVRO') AS SELECT TIMESTAMP '1994-09-27 11:23:45.678' my_timestamp"; + + //avro only stores as millis + assertUpdate(createTable, 1); + + // access with multiple precisions + assertQuery(withTimestampPrecision(getSession(), HiveTimestampPrecision.MILLISECONDS), + "SELECT * from test_avro_timestamp_upcasting", + "VALUES (TIMESTAMP '1994-09-27 11:23:45.678')"); + + // access with multiple precisions + assertQuery(withTimestampPrecision(getSession(), HiveTimestampPrecision.MICROSECONDS), + "SELECT * from test_avro_timestamp_upcasting", + "VALUES (TIMESTAMP '1994-09-27 11:23:45.678000')"); + + // access with multiple precisions + assertQuery(withTimestampPrecision(getSession(), HiveTimestampPrecision.NANOSECONDS), + "SELECT * from test_avro_timestamp_upcasting", + "VALUES (TIMESTAMP '1994-09-27 11:23:45.678000000')"); + } + @Test public void testOrderByChar() { @@ -5503,10 +5527,10 @@ private void schemaMismatchesWithDereferenceProjections(HiveStorageFormat format // eg. 
table column: a row(c varchar, b bigint), partition column: a row(b bigint, c varchar) try { assertUpdate("CREATE TABLE evolve_test (dummy bigint, a row(b bigint, c varchar), d bigint) with (format = '" + format + "', partitioned_by=array['d'])"); - assertUpdate("INSERT INTO evolve_test values (1, row(1, 'abc'), 1)", 1); + assertUpdate("INSERT INTO evolve_test values (10, row(1, 'abc'), 1)", 1); assertUpdate("ALTER TABLE evolve_test DROP COLUMN a"); assertUpdate("ALTER TABLE evolve_test ADD COLUMN a row(c varchar, b bigint)"); - assertUpdate("INSERT INTO evolve_test values (2, row('def', 2), 2)", 1); + assertUpdate("INSERT INTO evolve_test values (20, row('def', 2), 2)", 1); assertQueryFails("SELECT a.b FROM evolve_test where d = 1", ".*There is a mismatch between the table and partition schemas.*"); } finally { @@ -5517,10 +5541,10 @@ private void schemaMismatchesWithDereferenceProjections(HiveStorageFormat format // i.e. "a.c" produces null for rows that were inserted before type of "a" was changed try { assertUpdate("CREATE TABLE evolve_test (dummy bigint, a row(b bigint), d bigint) with (format = '" + format + "', partitioned_by=array['d'])"); - assertUpdate("INSERT INTO evolve_test values (1, row(1), 1)", 1); + assertUpdate("INSERT INTO evolve_test values (10, row(1), 1)", 1); assertUpdate("ALTER TABLE evolve_test DROP COLUMN a"); assertUpdate("ALTER TABLE evolve_test ADD COLUMN a row(b bigint, c varchar)"); - assertUpdate("INSERT INTO evolve_test values (2, row(2, 'def'), 2)", 1); + assertUpdate("INSERT INTO evolve_test values (20, row(2, 'def'), 2)", 1); assertQuery("SELECT a.c FROM evolve_test", "SELECT 'def' UNION SELECT null"); } finally { @@ -5530,10 +5554,10 @@ private void schemaMismatchesWithDereferenceProjections(HiveStorageFormat format // Verify field access when the row evolves without changes to field type try { assertUpdate("CREATE TABLE evolve_test (dummy bigint, a row(b bigint, c varchar), d bigint) with (format = '" + format + "', 
partitioned_by=array['d'])"); - assertUpdate("INSERT INTO evolve_test values (1, row(1, 'abc'), 1)", 1); + assertUpdate("INSERT INTO evolve_test values (10, row(1, 'abc'), 1)", 1); assertUpdate("ALTER TABLE evolve_test DROP COLUMN a"); assertUpdate("ALTER TABLE evolve_test ADD COLUMN a row(b bigint, c varchar, e int)"); - assertUpdate("INSERT INTO evolve_test values (2, row(2, 'def', 2), 2)", 1); + assertUpdate("INSERT INTO evolve_test values (20, row(2, 'def', 2), 2)", 1); assertQuery("SELECT a.b FROM evolve_test", "VALUES 1, 2"); } finally { @@ -5545,7 +5569,7 @@ private void schemaMismatchesWithDereferenceProjections(HiveStorageFormat format public void testSubfieldReordering() { // Validate for formats for which subfield access is name based - List formats = ImmutableList.of(HiveStorageFormat.ORC, HiveStorageFormat.PARQUET); + List formats = ImmutableList.of(HiveStorageFormat.ORC, HiveStorageFormat.PARQUET, HiveStorageFormat.AVRO); for (HiveStorageFormat format : formats) { // Subfields reordered in the file are read correctly. e.g. 
if partition column type is row(b bigint, c varchar) but the file diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java index 7414db5793a3..4533c3bdd19d 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java @@ -24,6 +24,7 @@ import io.trino.hive.formats.compression.CompressionKind; import io.trino.orc.OrcReaderOptions; import io.trino.orc.OrcWriterOptions; +import io.trino.plugin.hive.avro.AvroHivePageSourceFactory; import io.trino.plugin.hive.line.CsvFileWriterFactory; import io.trino.plugin.hive.line.CsvPageSourceFactory; import io.trino.plugin.hive.line.JsonFileWriterFactory; @@ -432,6 +433,7 @@ public void testAvro(int rowCount, long fileSizePadding) .withColumns(getTestColumnsSupportedByAvro()) .withRowsCount(rowCount) .withFileSizePadding(fileSizePadding) + .isReadableByPageSource(new AvroHivePageSourceFactory(FILE_SYSTEM_FACTORY, STATS)) .isReadableByRecordCursor(createGenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)); } @@ -448,6 +450,7 @@ public void testAvroFileInSymlinkTable(int rowCount) splitProperties.setProperty(FILE_INPUT_FORMAT, SymlinkTextInputFormat.class.getName()); splitProperties.setProperty(SERIALIZATION_LIB, AVRO.getSerde()); testCursorProvider(createGenericHiveRecordCursorProvider(HDFS_ENVIRONMENT), split, splitProperties, getTestColumnsSupportedByAvro(), SESSION, file.length(), rowCount); + testPageSourceFactory(new AvroHivePageSourceFactory(FILE_SYSTEM_FACTORY, STATS), split, AVRO, getTestColumnsSupportedByAvro(), SESSION, file.length(), rowCount); } finally { //noinspection ResultOfMethodCallIgnored @@ -594,7 +597,8 @@ public void testTruncateVarcharColumn() assertThatFileFormat(AVRO) .withWriteColumns(ImmutableList.of(writeColumn)) .withReadColumns(ImmutableList.of(readColumn)) - 
.isReadableByRecordCursor(createGenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)); + .isReadableByRecordCursor(createGenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)) + .isReadableByPageSource(new AvroHivePageSourceFactory(FILE_SYSTEM_FACTORY, STATS)); assertThatFileFormat(SEQUENCEFILE) .withWriteColumns(ImmutableList.of(writeColumn)) @@ -631,7 +635,8 @@ public void testAvroProjectedColumns(int rowCount) .withWriteColumns(writeColumns) .withReadColumns(readColumns) .withRowsCount(rowCount) - .isReadableByRecordCursorPageSource(createGenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)); + .isReadableByRecordCursorPageSource(createGenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)) + .isReadableByPageSource(new AvroHivePageSourceFactory(FILE_SYSTEM_FACTORY, STATS)); } @Test(dataProvider = "rowCount") diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFormatsConfig.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFormatsConfig.java index d2ec74bdb540..687b14917a2c 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFormatsConfig.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFormatsConfig.java @@ -28,6 +28,7 @@ public class TestHiveFormatsConfig public void testDefaults() { assertRecordedDefaults(recordDefaults(HiveFormatsConfig.class) + .setAvroFileNativeReaderEnabled(true) .setCsvNativeReaderEnabled(true) .setCsvNativeWriterEnabled(true) .setJsonNativeReaderEnabled(true) @@ -45,6 +46,7 @@ public void testDefaults() public void testExplicitPropertyMappings() { Map properties = ImmutableMap.builder() + .put("avro.native-reader.enabled", "false") .put("csv.native-reader.enabled", "false") .put("csv.native-writer.enabled", "false") .put("json.native-reader.enabled", "false") @@ -59,6 +61,7 @@ public void testExplicitPropertyMappings() .buildOrThrow(); HiveFormatsConfig expected = new HiveFormatsConfig() + .setAvroFileNativeReaderEnabled(false) .setCsvNativeReaderEnabled(false) 
.setCsvNativeWriterEnabled(false) .setJsonNativeReaderEnabled(false) diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/BaseTestAvroSchemaEvolution.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/BaseTestAvroSchemaEvolution.java index d62c601f4214..7a59f8f1161b 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/BaseTestAvroSchemaEvolution.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/BaseTestAvroSchemaEvolution.java @@ -155,9 +155,9 @@ public void testSchemaEvolutionWithIncompatibleType() alterTableSchemaLiteral(readSchemaLiteralFromUrl(INCOMPATIBLE_TYPE_SCHEMA)); assertQueryFailure(() -> onTrino().executeQuery(format(selectStarStatement, tableWithSchemaUrl))) - .hasMessageContaining("Found int, expecting string"); + .hasStackTraceContaining("Found int, expecting string"); assertQueryFailure(() -> onTrino().executeQuery(format(selectStarStatement, tableWithSchemaLiteral))) - .hasMessageContaining("Found int, expecting string"); + .hasStackTraceContaining("Found int, expecting string"); } @Test(groups = AVRO) diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java index 44284e3c4903..32034879c0ab 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java @@ -25,6 +25,7 @@ import java.util.List; import static io.trino.testing.TestingNames.randomNameSuffix; +import static io.trino.tests.product.TestGroups.AVRO; import static io.trino.tests.product.TestGroups.SMOKE; import static io.trino.tests.product.utils.QueryExecutors.onHive; import static io.trino.tests.product.utils.QueryExecutors.onTrino; @@ -133,7 +134,7 @@ public static Object[][] 
unionDereferenceTestCases() "DROP TABLE IF EXISTS " + tableUnionDereference}}; } - @Test(dataProvider = "storage_formats", groups = SMOKE) + @Test(dataProvider = "storage_formats", groups = {SMOKE, AVRO}) public void testReadUniontype(String storageFormat) { // According to testing results, the Hive INSERT queries here only work in Hive 1.2 @@ -219,7 +220,7 @@ public void testReadUniontype(String storageFormat) } } - @Test(dataProvider = "union_dereference_test_cases", groups = SMOKE) + @Test(dataProvider = "union_dereference_test_cases", groups = {SMOKE, AVRO}) public void testReadUniontypeWithDereference(String createTableSql, String insertSql, String selectSql, List expectedResult, String selectTagSql, List expectedTagResult, String dropTableSql) { // According to testing results, the Hive INSERT queries here only work in Hive 1.2 @@ -238,7 +239,7 @@ public void testReadUniontypeWithDereference(String createTableSql, String inser onTrino().executeQuery(dropTableSql); } - @Test(dataProvider = "storage_formats", groups = SMOKE) + @Test(dataProvider = "storage_formats", groups = {SMOKE, AVRO}) public void testUnionTypeSchemaEvolution(String storageFormat) { // According to testing results, the Hive INSERT queries here only work in Hive 1.2