diff --git a/java/core/src/java/org/apache/orc/PhysicalWriter.java b/java/core/src/java/org/apache/orc/PhysicalWriter.java index e25e81c046..872d5dcde7 100644 --- a/java/core/src/java/org/apache/orc/PhysicalWriter.java +++ b/java/core/src/java/org/apache/orc/PhysicalWriter.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -22,6 +22,8 @@ import java.nio.ByteBuffer; import org.apache.orc.impl.StreamName; +import org.apache.orc.impl.writer.StreamOptions; +import org.apache.orc.impl.writer.WriterEncryptionVariant; /** * This interface separates the physical layout of ORC files from the higher @@ -39,7 +41,6 @@ interface OutputReceiver { * Output the given buffer to the final destination * * @param buffer the buffer to output - * @throws IOException */ void output(ByteBuffer buffer) throws IOException; @@ -48,16 +49,15 @@ interface OutputReceiver { */ void suppress(); } + /** * Writes the header of the file, which consists of the magic "ORC" bytes. - * @throws IOException */ void writeHeader() throws IOException; /** * Create an OutputReceiver for the given name. * @param name the name of the stream - * @throws IOException */ OutputReceiver createDataStream(StreamName name) throws IOException; @@ -65,7 +65,6 @@ interface OutputReceiver { * Write an index in the given stream name. * @param name the name of the stream * @param index the bloom filter to write - * @param codec the compression codec to use */ void writeIndex(StreamName name, OrcProto.RowIndex.Builder index) throws IOException; @@ -74,7 +73,6 @@ void writeIndex(StreamName name, * Write a bloom filter index in the given stream name. * @param name the name of the stream * @param bloom the bloom filter to write - * @param codec the compression codec to use */ void writeBloomFilter(StreamName name, OrcProto.BloomFilterIndex.Builder bloom) throws IOException; @@ -89,6 +87,16 @@ void writeBloomFilter(StreamName name, void finalizeStripe(OrcProto.StripeFooter.Builder footer, OrcProto.StripeInformation.Builder dirEntry) throws IOException; + /** + * Write a stripe or file statistics to the file. + * @param name the name of the stream + * @param statistics the statistics to write + * @throws IOException + */ + void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder statistics + ) throws IOException; + /** * Writes out the file metadata. * @param builder Metadata builder to finalize and write. @@ -122,19 +130,24 @@ void finalizeStripe(OrcProto.StripeFooter.Builder footer, * @param stripe Stripe data buffer. * @param dirEntry File metadata entry for the stripe, to be updated with * relevant data. - * @throws IOException */ void appendRawStripe(ByteBuffer stripe, OrcProto.StripeInformation.Builder dirEntry ) throws IOException; - /** Gets a compression codec used by this writer. */ - CompressionCodec getCompressionCodec(); - /** * Get the number of bytes for a file in a givem column. * @param column column from which to get file size + * @param variant the encryption variant to check * @return number of bytes for the given column */ - long getFileBytes(int column); + long getFileBytes(int column, WriterEncryptionVariant variant); + + /** + * Get the unencrypted stream options for this file. This class needs the + * stream options to write the indexes and footers. + * + * Additionally, the LLAP CacheWriter wants to disable the generic compression. 
+ */ + StreamOptions getStreamOptions(); } diff --git a/java/core/src/java/org/apache/orc/impl/BitFieldWriter.java b/java/core/src/java/org/apache/orc/impl/BitFieldWriter.java index 1b1cd10474..044271cad5 100644 --- a/java/core/src/java/org/apache/orc/impl/BitFieldWriter.java +++ b/java/core/src/java/org/apache/orc/impl/BitFieldWriter.java @@ -18,6 +18,7 @@ package org.apache.orc.impl; import java.io.IOException; +import java.util.function.Consumer; public class BitFieldWriter { private RunLengthByteWriter output; @@ -70,4 +71,8 @@ public void getPosition(PositionRecorder recorder) throws IOException { public long estimateMemory() { return output.estimateMemory(); } + + public void changeIv(Consumer modifier) { + output.changeIv(modifier); + } } diff --git a/java/core/src/java/org/apache/orc/impl/IntegerWriter.java b/java/core/src/java/org/apache/orc/impl/IntegerWriter.java index 70b16d3e9f..19e843b09d 100644 --- a/java/core/src/java/org/apache/orc/impl/IntegerWriter.java +++ b/java/core/src/java/org/apache/orc/impl/IntegerWriter.java @@ -19,6 +19,7 @@ package org.apache.orc.impl; import java.io.IOException; +import java.util.function.Consumer; /** * Interface for writing integers. @@ -50,4 +51,6 @@ public interface IntegerWriter { * @return number of bytes */ long estimateMemory(); + + void changeIv(Consumer modifier); } diff --git a/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java b/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java index aee16b16db..4736a63052 100644 --- a/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java +++ b/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java @@ -22,19 +22,25 @@ import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.TreeMap; +import com.google.protobuf.ByteString; import com.google.protobuf.CodedOutputStream; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.orc.CompressionCodec; +import org.apache.orc.EncryptionVariant; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.PhysicalWriter; +import org.apache.orc.TypeDescription; +import org.apache.orc.impl.writer.WriterEncryptionKey; +import org.apache.orc.impl.writer.WriterEncryptionVariant; import org.apache.orc.impl.writer.StreamOptions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,10 +52,12 @@ public class PhysicalFsWriter implements PhysicalWriter { private static final int HDFS_BUFFER_SIZE = 256 * 1024; private FSDataOutputStream rawWriter; + private final DirectStream rawStream; + // the compressed metadata information outStream - private OutStream writer; + private OutStream compressStream; // a protobuf outStream around streamFactory - private CodedOutputStream protobufWriter; + private CodedOutputStream codedCompressStream; private final Path path; private final HadoopShims shims; @@ -59,10 +67,7 @@ public class PhysicalFsWriter implements PhysicalWriter { private final OrcFile.CompressionStrategy compressionStrategy; private final boolean addBlockPadding; private final boolean writeVariableLengthBlocks; - - // the streams that make up the current stripe - private final Map streams = - new TreeMap<>(); + private final VariantTracker unencrypted; private long headerLength; private long stripeStart; @@ -70,11 +75,24 @@ public class PhysicalFsWriter implements PhysicalWriter { // natural blocks 
private long blockOffset; private int metadataLength; + private int stripeStatisticsLength = 0; private int footerLength; + private int stripeNumber = 0; + + private final Map variants = new TreeMap<>(); + + public PhysicalFsWriter(FileSystem fs, + Path path, + OrcFile.WriterOptions opts + ) throws IOException { + this(fs, path, opts, new WriterEncryptionVariant[0]); + } public PhysicalFsWriter(FileSystem fs, Path path, - OrcFile.WriterOptions opts) throws IOException { + OrcFile.WriterOptions opts, + WriterEncryptionVariant[] encryption + ) throws IOException { this.path = path; long defaultStripeSize = opts.getStripeSize(); this.addBlockPadding = opts.getBlockPadding(); @@ -98,16 +116,124 @@ public PhysicalFsWriter(FileSystem fs, rawWriter = fs.create(path, opts.getOverwrite(), HDFS_BUFFER_SIZE, fs.getDefaultReplication(path), blockSize); blockOffset = 0; - writer = new OutStream("metadata", compress, - new DirectStream(rawWriter)); - protobufWriter = CodedOutputStream.newInstance(writer); + unencrypted = new VariantTracker(opts.getSchema(), compress); writeVariableLengthBlocks = opts.getWriteVariableLengthBlocks(); shims = opts.getHadoopShims(); + rawStream = new DirectStream(rawWriter); + compressStream = new OutStream("stripe footer", compress, rawStream); + codedCompressStream = CodedOutputStream.newInstance(compressStream); + for(WriterEncryptionVariant variant: encryption) { + WriterEncryptionKey key = variant.getKeyDescription(); + StreamOptions encryptOptions = + new StreamOptions(unencrypted.options) + .withEncryption(key.getAlgorithm(), variant.getFileFooterKey()); + variants.put(variant, new VariantTracker(variant.getRoot(), encryptOptions)); + } } - @Override - public CompressionCodec getCompressionCodec() { - return compress.getCodec(); + /** + * Record the information about each column encryption variant. + * The unencrypted data and each encrypted column root are variants. + */ + protected static class VariantTracker { + // the streams that make up the current stripe + protected final Map streams = new TreeMap<>(); + private final int rootColumn; + private final int lastColumn; + protected final StreamOptions options; + // a list for each column covered by this variant + // the elements in the list correspond to each stripe in the file + protected final List[] stripeStats; + protected final List stripeStatsStreams = new ArrayList<>(); + protected final OrcProto.ColumnStatistics[] fileStats; + + VariantTracker(TypeDescription schema, StreamOptions options) { + rootColumn = schema.getId(); + lastColumn = schema.getMaximumId(); + this.options = options; + stripeStats = new List[schema.getMaximumId() - schema.getId() + 1]; + for(int i=0; i < stripeStats.length; ++i) { + stripeStats[i] = new ArrayList<>(); + } + fileStats = new OrcProto.ColumnStatistics[stripeStats.length]; + } + + public BufferedStream createStream(StreamName name) { + BufferedStream result = new BufferedStream(); + streams.put(name, result); + return result; + } + + /** + * Place the streams in the appropriate area while updating the sizes + * with the number of bytes in the area. 
+ * @param area the area to write + * @param sizes the sizes of the areas + * @return the list of stream descriptions to add + */ + public List placeStreams(StreamName.Area area, + SizeCounters sizes) { + List result = new ArrayList<>(streams.size()); + for(Map.Entry stream: streams.entrySet()) { + StreamName name = stream.getKey(); + BufferedStream bytes = stream.getValue(); + if (name.getArea() == area && !bytes.isSuppressed) { + OrcProto.Stream.Builder builder = OrcProto.Stream.newBuilder(); + long size = bytes.getOutputSize(); + if (area == StreamName.Area.INDEX) { + sizes.index += size; + } else { + sizes.data += size; + } + builder.setColumn(name.getColumn()) + .setKind(name.getKind()) + .setLength(size); + result.add(builder.build()); + } + } + return result; + } + + /** + * Write the streams in the appropriate area. + * @param area the area to write + * @param raw the raw stream to write to + */ + public void writeStreams(StreamName.Area area, + FSDataOutputStream raw) throws IOException { + for(Map.Entry stream: streams.entrySet()) { + if (stream.getKey().getArea() == area) { + stream.getValue().spillToDiskAndClear(raw); + } + } + } + + /** + * Computed the size of the given column on disk for this stripe. + * It excludes the index streams. + * @param column a column id + * @return the total number of bytes + */ + public long getFileBytes(int column) { + long result = 0; + if (column >= rootColumn && column <= lastColumn) { + for(Map.Entry entry: streams.entrySet()) { + StreamName name = entry.getKey(); + if (name.getColumn() == column && + name.getArea() != StreamName.Area.INDEX) { + result += entry.getValue().getOutputSize(); + } + } + } + return result; + } + } + + VariantTracker getVariant(EncryptionVariant column) { + if (column == null) { + return unencrypted; + } + return variants.get(column); } /** @@ -120,20 +246,13 @@ public CompressionCodec getCompressionCodec() { * @return number of bytes for the given column */ @Override - public long getFileBytes(final int column) { - long size = 0; - for (final Map.Entry pair: streams.entrySet()) { - final BufferedStream receiver = pair.getValue(); - if(!receiver.isSuppressed) { - - final StreamName name = pair.getKey(); - if(name.getColumn() == column && name.getArea() != StreamName.Area.INDEX ) { - size += receiver.getOutputSize(); - } - } + public long getFileBytes(int column, WriterEncryptionVariant variant) { + return getVariant(variant).getFileBytes(column); + } - } - return size; + @Override + public StreamOptions getStreamOptions() { + return unencrypted.options; } private static final byte[] ZEROS = new byte[64*1024]; @@ -198,36 +317,139 @@ public void suppress() { } private void writeStripeFooter(OrcProto.StripeFooter footer, - long dataSize, - long indexSize, + SizeCounters sizes, OrcProto.StripeInformation.Builder dirEntry) throws IOException { - footer.writeTo(protobufWriter); - protobufWriter.flush(); - writer.flush(); + footer.writeTo(codedCompressStream); + codedCompressStream.flush(); + compressStream.flush(); dirEntry.setOffset(stripeStart); - dirEntry.setFooterLength(rawWriter.getPos() - stripeStart - dataSize - indexSize); + dirEntry.setFooterLength(rawWriter.getPos() - stripeStart - sizes.total()); + } + + /** + * Write the saved encrypted stripe statistic in a variant out to the file. + * The streams that are written are added to the tracker.stripeStatsStreams. 
+ * @param output the file we are writing to + * @param stripeNumber the number of stripes in the file + * @param tracker the variant to write out + */ + static void writeEncryptedStripeStatistics(DirectStream output, + int stripeNumber, + VariantTracker tracker + ) throws IOException { + StreamOptions options = new StreamOptions(tracker.options); + tracker.stripeStatsStreams.clear(); + for(int col = tracker.rootColumn; + col < tracker.rootColumn + tracker.stripeStats.length; ++col) { + options.modifyIv(CryptoUtils.modifyIvForStream(col, + OrcProto.Stream.Kind.STRIPE_STATISTICS, stripeNumber)); + OutStream stream = new OutStream("stripe stats for " + col, + options, output); + OrcProto.ColumnarStripeStatistics stats = + OrcProto.ColumnarStripeStatistics.newBuilder() + .addAllColStats(tracker.stripeStats[col - tracker.rootColumn]) + .build(); + long start = output.output.getPos(); + stats.writeTo(stream); + stream.flush(); + OrcProto.Stream description = OrcProto.Stream.newBuilder() + .setColumn(col) + .setKind(OrcProto.Stream.Kind.STRIPE_STATISTICS) + .setLength(output.output.getPos() - start) + .build(); + tracker.stripeStatsStreams.add(description); + } + } + + /** + * Merge the saved unencrypted stripe statistics into the Metadata section + * of the footer. + * @param builder the Metadata section of the file + * @param stripeCount the number of stripes in the file + * @param stats the stripe statistics + */ + static void setUnencryptedStripeStatistics(OrcProto.Metadata.Builder builder, + int stripeCount, + List[] stats) { + // Make the unencrypted stripe stats into lists of StripeStatistics. + builder.clearStripeStats(); + for(int s=0; s < stripeCount; ++s) { + OrcProto.StripeStatistics.Builder stripeStats = + OrcProto.StripeStatistics.newBuilder(); + for(List col: stats) { + stripeStats.addColStats(col.get(s)); + } + builder.addStripeStats(stripeStats.build()); + } + } + + static void setEncryptionStatistics(OrcProto.Encryption.Builder encryption, + int stripeNumber, + Collection variants + ) throws IOException { + int v = 0; + for(VariantTracker variant: variants) { + OrcProto.EncryptionVariant.Builder variantBuilder = + encryption.getVariantsBuilder(v++); + + // Add the stripe statistics streams to the variant description. + variantBuilder.clearStripeStatistics(); + variantBuilder.addAllStripeStatistics(variant.stripeStatsStreams); + + // Serialize and encrypt the file statistics. 
+ OrcProto.FileStatistics.Builder file = OrcProto.FileStatistics.newBuilder(); + for(OrcProto.ColumnStatistics col: variant.fileStats) { + file.addColumn(col); + } + StreamOptions options = new StreamOptions(variant.options); + options.modifyIv(CryptoUtils.modifyIvForStream(variant.rootColumn, + OrcProto.Stream.Kind.FILE_STATISTICS, stripeNumber)); + BufferedStream buffer = new BufferedStream(); + OutStream stream = new OutStream("stats for " + variant, options, buffer); + file.build().writeTo(stream); + stream.flush(); + variantBuilder.setFileStatistics(buffer.getBytes()); + } } @Override public void writeFileMetadata(OrcProto.Metadata.Builder builder) throws IOException { - long startPosn = rawWriter.getPos(); - OrcProto.Metadata metadata = builder.build(); - metadata.writeTo(protobufWriter); - protobufWriter.flush(); - writer.flush(); - this.metadataLength = (int) (rawWriter.getPos() - startPosn); + long stripeStatisticsStart = rawWriter.getPos(); + for(VariantTracker variant: variants.values()) { + writeEncryptedStripeStatistics(rawStream, stripeNumber, variant); + } + setUnencryptedStripeStatistics(builder, stripeNumber, unencrypted.stripeStats); + long metadataStart = rawWriter.getPos(); + builder.build().writeTo(codedCompressStream); + codedCompressStream.flush(); + compressStream.flush(); + this.stripeStatisticsLength = (int) (metadataStart - stripeStatisticsStart); + this.metadataLength = (int) (rawWriter.getPos() - metadataStart); + } + + static void addUnencryptedStatistics(OrcProto.Footer.Builder builder, + OrcProto.ColumnStatistics[] stats) { + for(OrcProto.ColumnStatistics stat: stats) { + builder.addStatistics(stat); + } } @Override public void writeFileFooter(OrcProto.Footer.Builder builder) throws IOException { - long bodyLength = rawWriter.getPos() - metadataLength; + if (variants.size() > 0) { + OrcProto.Encryption.Builder encryption = builder.getEncryptionBuilder(); + setEncryptionStatistics(encryption, stripeNumber, variants.values()); + builder.setStripeStatisticsLength(stripeStatisticsLength); + } + addUnencryptedStatistics(builder, unencrypted.fileStats); + long bodyLength = rawWriter.getPos() - metadataLength - stripeStatisticsLength; builder.setContentLength(bodyLength); builder.setHeaderLength(headerLength); long startPosn = rawWriter.getPos(); OrcProto.Footer footer = builder.build(); - footer.writeTo(protobufWriter); - protobufWriter.flush(); - writer.flush(); + footer.writeTo(codedCompressStream); + codedCompressStream.flush(); + compressStream.flush(); this.footerLength = (int) (rawWriter.getPos() - startPosn); } @@ -300,7 +522,7 @@ public void appendRawStripe(ByteBuffer buffer, * data as buffers fill up and stores them in the output list. When the * stripe is being written, the whole stream is written to the file. */ - private static final class BufferedStream implements OutputReceiver { + static final class BufferedStream implements OutputReceiver { private boolean isSuppressed = false; private final List output = new ArrayList<>(); @@ -319,17 +541,56 @@ public void suppress() { /** * Write any saved buffers to the OutputStream if needed, and clears all the * buffers. 
+ * @return true if the stream was written */ - void spillToDiskAndClear(FSDataOutputStream raw - ) throws IOException { + boolean spillToDiskAndClear(FSDataOutputStream raw) throws IOException { if (!isSuppressed) { for (ByteBuffer buffer: output) { raw.write(buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); } output.clear(); + return true; } isSuppressed = false; + return false; + } + + /** + * Get the buffer as a protobuf ByteString and clears the BufferedStream. + * @return the bytes + */ + ByteString getBytes() { + int len = output.size(); + if (len == 0) { + return ByteString.EMPTY; + } else { + ByteString result = ByteString.copyFrom(output.get(0)); + for (int i=1; i < output.size(); ++i) { + result = result.concat(ByteString.copyFrom(output.get(i))); + } + output.clear(); + return result; + } + } + + /** + * Get the stream as a ByteBuffer and clear it. + * @return a single ByteBuffer with the contents of the stream + */ + ByteBuffer getByteBuffer() { + ByteBuffer result; + if (output.size() == 1) { + result = output.get(0); + } else { + result = ByteBuffer.allocate((int) getOutputSize()); + for (ByteBuffer buffer : output) { + result.put(buffer); + } + output.clear(); + result.flip(); + } + return result; } /** @@ -347,38 +608,86 @@ public long getOutputSize() { } } + static class SizeCounters { + long index = 0; + long data = 0; + + long total() { + return index + data; + } + } + + void buildStreamList(OrcProto.StripeFooter.Builder footerBuilder, + SizeCounters sizes + ) throws IOException { + footerBuilder.addAllStreams( + unencrypted.placeStreams(StreamName.Area.INDEX, sizes)); + final long unencryptedIndexSize = sizes.index; + int v = 0; + for (VariantTracker variant: variants.values()) { + OrcProto.StripeEncryptionVariant.Builder builder = + footerBuilder.getEncryptionBuilder(v++); + builder.addAllStreams( + variant.placeStreams(StreamName.Area.INDEX, sizes)); + } + if (sizes.index != unencryptedIndexSize) { + // add a placeholder that covers the hole where the encrypted indexes are + footerBuilder.addStreams(OrcProto.Stream.newBuilder() + .setKind(OrcProto.Stream.Kind.ENCRYPTED_INDEX) + .setLength(sizes.index - unencryptedIndexSize)); + } + footerBuilder.addAllStreams( + unencrypted.placeStreams(StreamName.Area.DATA, sizes)); + final long unencryptedDataSize = sizes.data; + v = 0; + for (VariantTracker variant: variants.values()) { + OrcProto.StripeEncryptionVariant.Builder builder = + footerBuilder.getEncryptionBuilder(v++); + builder.addAllStreams( + variant.placeStreams(StreamName.Area.DATA, sizes)); + } + if (sizes.data != unencryptedDataSize) { + // add a placeholder that covers the hole where the encrypted indexes are + footerBuilder.addStreams(OrcProto.Stream.newBuilder() + .setKind(OrcProto.Stream.Kind.ENCRYPTED_DATA) + .setLength(sizes.data - unencryptedDataSize)); + } + } + @Override public void finalizeStripe(OrcProto.StripeFooter.Builder footerBuilder, OrcProto.StripeInformation.Builder dirEntry ) throws IOException { - long indexSize = 0; - long dataSize = 0; - for (Map.Entry pair: streams.entrySet()) { - BufferedStream receiver = pair.getValue(); - if (!receiver.isSuppressed) { - long streamSize = receiver.getOutputSize(); - StreamName name = pair.getKey(); - footerBuilder.addStreams(OrcProto.Stream.newBuilder().setColumn(name.getColumn()) - .setKind(name.getKind()).setLength(streamSize)); - if (StreamName.Area.INDEX == name.getArea()) { - indexSize += streamSize; - } else { - dataSize += streamSize; - } - } - } - 
dirEntry.setIndexLength(indexSize).setDataLength(dataSize); + SizeCounters sizes = new SizeCounters(); + buildStreamList(footerBuilder, sizes); OrcProto.StripeFooter footer = footerBuilder.build(); + // Do we need to pad the file so the stripe doesn't straddle a block boundary? - padStripe(indexSize + dataSize + footer.getSerializedSize()); + padStripe(sizes.total() + footer.getSerializedSize()); + + // write the unencrypted index streams + unencrypted.writeStreams(StreamName.Area.INDEX, rawWriter); + // write the encrypted index streams + for (VariantTracker variant: variants.values()) { + variant.writeStreams(StreamName.Area.INDEX, rawWriter); + } - // write out the data streams - for (Map.Entry pair : streams.entrySet()) { - pair.getValue().spillToDiskAndClear(rawWriter); + // write the unencrypted data streams + unencrypted.writeStreams(StreamName.Area.DATA, rawWriter); + // write out the unencrypted data streams + for (VariantTracker variant: variants.values()) { + variant.writeStreams(StreamName.Area.DATA, rawWriter); } + // Write out the footer. - writeStripeFooter(footer, dataSize, indexSize, dirEntry); + writeStripeFooter(footer, sizes, dirEntry); + + // fill in the data sizes + dirEntry.setDataLength(sizes.data); + dirEntry.setIndexLength(sizes.index); + + stripeNumber += 1; } @Override @@ -389,10 +698,11 @@ public void writeHeader() throws IOException { @Override public BufferedStream createDataStream(StreamName name) { - BufferedStream result = streams.get(name); + VariantTracker variant = getVariant(name.getEncryption()); + BufferedStream result = variant.streams.get(name); if (result == null) { result = new BufferedStream(); - streams.put(name, result); + variant.streams.put(name, result); } return result; } @@ -402,11 +712,26 @@ private StreamOptions getOptions(OrcProto.Stream.Kind kind) { kind); } + protected OutputStream createIndexStream(StreamName name) { + BufferedStream buffer = createDataStream(name); + VariantTracker tracker = getVariant(name.getEncryption()); + StreamOptions options = + SerializationUtils.getCustomizedCodec(tracker.options, + compressionStrategy, name.getKind()); + if (options.isEncrypted()) { + if (options == tracker.options) { + options = new StreamOptions(options); + } + options.modifyIv(CryptoUtils.modifyIvForStream(name, stripeNumber)); + } + return new OutStream(name.toString(), options, buffer); + } + @Override public void writeIndex(StreamName name, - OrcProto.RowIndex.Builder index) throws IOException { - OutputStream stream = new OutStream(path.toString(), - getOptions(name.getKind()), createDataStream(name)); + OrcProto.RowIndex.Builder index + ) throws IOException { + OutputStream stream = createIndexStream(name); index.build().writeTo(stream); stream.flush(); } @@ -415,12 +740,25 @@ public void writeIndex(StreamName name, public void writeBloomFilter(StreamName name, OrcProto.BloomFilterIndex.Builder bloom ) throws IOException { - OutputStream stream = new OutStream(path.toString(), - getOptions(name.getKind()), createDataStream(name)); + OutputStream stream = createIndexStream(name); bloom.build().writeTo(stream); stream.flush(); } + @Override + public void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder statistics + ) { + VariantTracker tracker = getVariant(name.getEncryption()); + if (name.getKind() == OrcProto.Stream.Kind.FILE_STATISTICS) { + tracker.fileStats[name.getColumn() - tracker.rootColumn] = + statistics.build(); + } else { + tracker.stripeStats[name.getColumn() - tracker.rootColumn] + 
.add(statistics.build()); + } + } + @Override public String toString() { return path.toString(); diff --git a/java/core/src/java/org/apache/orc/impl/PositionedOutputStream.java b/java/core/src/java/org/apache/orc/impl/PositionedOutputStream.java index d412939dba..fd6561fa66 100644 --- a/java/core/src/java/org/apache/orc/impl/PositionedOutputStream.java +++ b/java/core/src/java/org/apache/orc/impl/PositionedOutputStream.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.io.OutputStream; +import java.util.function.Consumer; public abstract class PositionedOutputStream extends OutputStream { @@ -36,4 +37,11 @@ public abstract void getPosition(PositionRecorder recorder * @return the number of bytes used by buffers. */ public abstract long getBufferSize(); + + /** + * Change the current Initialization Vector (IV) for the encryption. + * Has no effect if the stream is not encrypted. + * @param modifier a function to modify the IV in place + */ + public abstract void changeIv(Consumer modifier); } diff --git a/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java b/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java index c2f1fa74da..bfa1d7a048 100644 --- a/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java +++ b/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java @@ -18,6 +18,7 @@ package org.apache.orc.impl; import java.io.IOException; +import java.util.function.Consumer; /** * A streamFactory that writes a sequence of bytes. A control byte is written before @@ -107,4 +108,8 @@ public void getPosition(PositionRecorder recorder) throws IOException { public long estimateMemory() { return output.getBufferSize() + MAX_LITERAL_SIZE; } + + public void changeIv(Consumer modifier) { + output.changeIv(modifier); + } } diff --git a/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriter.java b/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriter.java index 88b47e6a6d..710f493f8d 100644 --- a/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriter.java +++ b/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriter.java @@ -18,6 +18,7 @@ package org.apache.orc.impl; import java.io.IOException; +import java.util.function.Consumer; /** * A streamFactory that writes a sequence of integers. A control byte is written before @@ -144,4 +145,9 @@ public void getPosition(PositionRecorder recorder) throws IOException { public long estimateMemory() { return output.getBufferSize(); } + + @Override + public void changeIv(Consumer modifier) { + output.changeIv(modifier); + } } diff --git a/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java b/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java index e4c2a051e6..9107774ee8 100644 --- a/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java +++ b/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java @@ -18,6 +18,7 @@ package org.apache.orc.impl; import java.io.IOException; +import java.util.function.Consumer; /** *
A writer that performs light weight compression over sequence of integers. @@ -823,4 +824,9 @@ public void getPosition(PositionRecorder recorder) throws IOException { public long estimateMemory() { return output.getBufferSize(); } + + @Override + public void changeIv(Consumer modifier) { + output.changeIv(modifier); + } } diff --git a/java/core/src/java/org/apache/orc/impl/WriterImpl.java b/java/core/src/java/org/apache/orc/impl/WriterImpl.java index 639f963398..7f9cb63254 100644 --- a/java/core/src/java/org/apache/orc/impl/WriterImpl.java +++ b/java/core/src/java/org/apache/orc/impl/WriterImpl.java @@ -20,11 +20,15 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.security.SecureRandom; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.SortedMap; +import java.util.SortedSet; import java.util.TimeZone; import java.util.TreeMap; +import java.util.TreeSet; import io.airlift.compress.lz4.Lz4Compressor; import io.airlift.compress.lz4.Lz4Decompressor; @@ -33,6 +37,7 @@ import org.apache.orc.ColumnStatistics; import org.apache.orc.CompressionCodec; import org.apache.orc.CompressionKind; +import org.apache.orc.DataMask; import org.apache.orc.MemoryManager; import org.apache.orc.OrcConf; import org.apache.orc.OrcFile; @@ -41,7 +46,8 @@ import org.apache.orc.PhysicalWriter; import org.apache.orc.StripeInformation; import org.apache.orc.TypeDescription; -import org.apache.orc.Writer; +import org.apache.orc.impl.writer.WriterEncryptionKey; +import org.apache.orc.impl.writer.WriterEncryptionVariant; import org.apache.orc.impl.writer.StreamOptions; import org.apache.orc.impl.writer.TreeWriter; import org.apache.orc.impl.writer.WriterContext; @@ -76,16 +82,17 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback { private static final Logger LOG = LoggerFactory.getLogger(WriterImpl.class); + private static final HadoopShims SHIMS = HadoopShimsFactory.get(); private static final int MIN_ROW_INDEX_STRIDE = 1000; private final Path path; private long adjustedStripeSize; private final int rowIndexStride; - private final StreamOptions compress; private final TypeDescription schema; private final PhysicalWriter physicalWriter; private final OrcFile.WriterVersion writerVersion; + private final StreamOptions unencryptedOptions; private long rowCount = 0; private long rowsInStripe = 0; @@ -95,8 +102,6 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback { private int stripesAtLastFlush = -1; private final List stripes = new ArrayList<>(); - private final OrcProto.Metadata.Builder fileMetadata = - OrcProto.Metadata.newBuilder(); private final Map userMetadata = new TreeMap<>(); private final TreeWriter treeWriter; @@ -115,14 +120,55 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback { private final boolean useUTCTimeZone; private final double dictionaryKeySizeThreshold; private final boolean[] directEncodingColumns; + private final List unencryptedEncodings = + new ArrayList<>(); + + // the list of maskDescriptions, keys, and variants + private SortedSet maskDescriptions = new TreeSet<>(); + private SortedMap keys = new TreeMap<>(); + private final WriterEncryptionVariant[] encryption; + // the mapping of columns to maskDescriptions + private final MaskDescriptionImpl[] columnMaskDescriptions; + // the mapping of columns to EncryptionVariants + private final WriterEncryptionVariant[] columnEncryption; + private HadoopShims.KeyProvider keyProvider; + // do we need to include the 
current encryption keys in the next stripe + // information + private boolean needKeyFlush; public WriterImpl(FileSystem fs, Path path, OrcFile.WriterOptions opts) throws IOException { + this.schema = opts.getSchema(); + int numColumns = schema.getMaximumId() + 1; + if (!opts.isEnforceBufferSize()) { + opts.bufferSize(getEstimatedBufferSize(opts.getStripeSize(), numColumns, + opts.getBufferSize())); + } + + // Do we have column encryption? + List encryptionOptions = opts.getEncryption(); + columnEncryption = new WriterEncryptionVariant[numColumns]; + if (encryptionOptions.isEmpty()) { + columnMaskDescriptions = null; + encryption = new WriterEncryptionVariant[0]; + needKeyFlush = false; + } else { + columnMaskDescriptions = new MaskDescriptionImpl[numColumns]; + encryption = setupEncryption(opts.getKeyProvider(), encryptionOptions); + needKeyFlush = true; + } + + // Set up the physical writer + this.physicalWriter = opts.getPhysicalWriter() == null ? + new PhysicalFsWriter(fs, path, opts, encryption) : + opts.getPhysicalWriter(); + unencryptedOptions = physicalWriter.getStreamOptions(); + OutStream.assertBufferSizeValid(unencryptedOptions.getBufferSize()); + this.path = path; this.conf = opts.getConfiguration(); this.callback = opts.getCallback(); - this.schema = opts.getSchema(); this.writerVersion = opts.getWriterVersion(); bloomFilterVersion = opts.getBloomFilterVersion(); this.directEncodingColumns = OrcUtils.includeColumns( @@ -130,13 +176,7 @@ public WriterImpl(FileSystem fs, dictionaryKeySizeThreshold = OrcConf.DICTIONARY_KEY_SIZE_THRESHOLD.getDouble(conf); if (callback != null) { - callbackContext = new OrcFile.WriterContext(){ - - @Override - public Writer getWriter() { - return WriterImpl.this; - } - }; + callbackContext = () -> WriterImpl.this; } else { callbackContext = null; } @@ -149,22 +189,6 @@ public Writer getWriter() { this.rowIndexStride = opts.getRowIndexStride(); this.memoryManager = opts.getMemoryManager(); buildIndex = rowIndexStride > 0; - int numColumns = schema.getMaximumId() + 1; - if (opts.isEnforceBufferSize()) { - OutStream.assertBufferSizeValid(opts.getBufferSize()); - compress = new StreamOptions(opts.getBufferSize()); - } else { - compress = new StreamOptions(getEstimatedBufferSize(adjustedStripeSize, - numColumns, opts.getBufferSize())); - } - this.physicalWriter = opts.getPhysicalWriter() == null - ? 
new PhysicalFsWriter(fs, path, opts) - : opts.getPhysicalWriter(); - physicalWriter.writeHeader(); - CompressionCodec codec = physicalWriter.getCompressionCodec(); - if (codec != null) { - compress.withCodec(codec, codec.getDefaultOptions()); - } if (version == OrcFile.Version.FUTURE) { throw new IllegalArgumentException("Can not write in a unknown version."); } else if (version == OrcFile.Version.UNSTABLE_PRE_2_0) { @@ -180,17 +204,17 @@ public Writer getWriter() { OrcUtils.includeColumns(opts.getBloomFilterColumns(), schema); } this.bloomFilterFpp = opts.getBloomFilterFpp(); - treeWriter = TreeWriter.Factory.create(schema, new StreamFactory(), false); + physicalWriter.writeHeader(); + + treeWriter = TreeWriter.Factory.create(schema, null, new StreamFactory()); if (buildIndex && rowIndexStride < MIN_ROW_INDEX_STRIDE) { throw new IllegalArgumentException("Row stride must be at least " + MIN_ROW_INDEX_STRIDE); } - // ensure that we are able to handle callbacks before we register ourselves memoryManager.addWriter(path, opts.getStripeSize(), this); - LOG.info("ORC writer created for path: {} with stripeSize: {} blockSize: {}" + - " compression: {}", path, adjustedStripeSize, opts.getBlockSize(), - compress); + LOG.info("ORC writer created for path: {} with stripeSize: {} options: {}", + path, adjustedStripeSize, unencryptedOptions); } //@VisibleForTesting @@ -207,8 +231,8 @@ public static int getEstimatedBufferSize(long stripeSize, int numColumns, @Override public void increaseCompressionSize(int newSize) { - if (newSize > compress.getBufferSize()) { - compress.bufferSize(newSize); + if (newSize > unencryptedOptions.getBufferSize()) { + unencryptedOptions.bufferSize(newSize); } } @@ -277,19 +301,26 @@ public boolean checkMemory(double newScale) throws IOException { * that the TreeWriters have into the Writer. */ private class StreamFactory implements WriterContext { + /** * Create a stream to store part of a column. - * @param column the column id for the stream - * @param kind the kind of stream + * @param name the name for the stream * @return The output outStream that the section needs to be written to. 
*/ - public OutStream createStream(int column, - OrcProto.Stream.Kind kind - ) throws IOException { - final StreamName name = new StreamName(column, kind); - return new OutStream(physicalWriter.toString(), - SerializationUtils.getCustomizedCodec(compress, compressionStrategy, kind), - physicalWriter.createDataStream(name)); + public OutStream createStream(StreamName name) throws IOException { + StreamOptions options = SerializationUtils.getCustomizedCodec( + unencryptedOptions, compressionStrategy, name.getKind()); + WriterEncryptionVariant encryption = + (WriterEncryptionVariant) name.getEncryption(); + if (encryption != null) { + if (options == unencryptedOptions) { + options = new StreamOptions(options); + } + options.withEncryption(encryption.getKeyDescription().getAlgorithm(), + encryption.getFileFooterKey()) + .modifyIv(CryptoUtils.modifyIvForStream(name, 1)); + } + return new OutStream(name, options, physicalWriter.createDataStream(name)); } /** @@ -312,7 +343,7 @@ public boolean buildIndex() { * @return are the streams compressed */ public boolean isCompressed() { - return physicalWriter.getCompressionCodec() != null; + return unencryptedOptions.getCodec() != null; } /** @@ -379,6 +410,37 @@ public void writeBloomFilter(StreamName name, physicalWriter.writeBloomFilter(name, bloom); } + @Override + public WriterEncryptionVariant getEncryption(int columnId) { + return columnId < columnEncryption.length ? + columnEncryption[columnId] : null; + } + + @Override + public DataMask getUnencryptedMask(int columnId) { + MaskDescriptionImpl descr = columnMaskDescriptions[columnId]; + return descr == null ? null : + DataMask.Factory.build(descr, schema.findSubtype(columnId), + (type) -> columnMaskDescriptions[type.getId()]); + } + + @Override + public void setEncoding(int column, WriterEncryptionVariant encryption, + OrcProto.ColumnEncoding encoding) { + if (encryption == null) { + unencryptedEncodings.add(encoding); + } else { + encryption.addEncoding(encoding); + } + } + + @Override + public void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder stats + ) throws IOException { + physicalWriter.writeStatistics(name, stats); + } + public boolean getUseUTCTimestamp() { return useUTCTimeZone; } @@ -399,6 +461,19 @@ private void createRowIndexEntry() throws IOException { rowsInIndex = 0; } + /** + * Write the encrypted keys into the StripeInformation along with the + * stripe id, so that the readers can decrypt the data. 
+ * @param dirEntry the entry to modify + */ + private void addEncryptedKeys(OrcProto.StripeInformation.Builder dirEntry) { + for(WriterEncryptionVariant variant: encryption) { + dirEntry.addEncryptedLocalKeys(ByteString.copyFrom( + variant.getMaterial().getEncryptedKey())); + } + dirEntry.setEncryptStripeId(1 + stripes.size()); + } + private void flushStripe() throws IOException { if (buildIndex && rowsInIndex != 0) { createRowIndexEntry(); @@ -419,18 +494,27 @@ private void flushStripe() throws IOException { builder.setWriterTimezone(TimeZone.getDefault().getID()); } } - OrcProto.StripeStatistics.Builder stats = - OrcProto.StripeStatistics.newBuilder(); - treeWriter.flushStreams(); - treeWriter.writeStripe(builder, stats, requiredIndexEntries); - + treeWriter.writeStripe(requiredIndexEntries); + // update the encodings + builder.addAllColumns(unencryptedEncodings); + unencryptedEncodings.clear(); + for (WriterEncryptionVariant writerEncryptionVariant : encryption) { + OrcProto.StripeEncryptionVariant.Builder encrypt = + OrcProto.StripeEncryptionVariant.newBuilder(); + encrypt.addAllEncoding(writerEncryptionVariant.getEncodings()); + writerEncryptionVariant.clearEncodings(); + builder.addEncryption(encrypt); + } OrcProto.StripeInformation.Builder dirEntry = OrcProto.StripeInformation.newBuilder() .setNumberOfRows(rowsInStripe); + if (encryption.length > 0 && needKeyFlush) { + addEncryptedKeys(dirEntry); + needKeyFlush = false; + } physicalWriter.finalizeStripe(builder, dirEntry); - fileMetadata.addStripeStats(stats.build()); stripes.add(dirEntry.build()); rowCount += rowsInStripe; rowsInStripe = 0; @@ -453,32 +537,73 @@ private OrcProto.CompressionKind writeCompressionKind(CompressionKind kind) { } } - private void writeFileStatistics(OrcProto.Footer.Builder builder, - TreeWriter writer) throws IOException { - writer.writeFileStatistics(builder); - } - private void writeMetadata() throws IOException { - physicalWriter.writeFileMetadata(fileMetadata); + // The physical writer now has the stripe statistics, so we pass a + // new builder in here. + physicalWriter.writeFileMetadata(OrcProto.Metadata.newBuilder()); } private long writePostScript() throws IOException { - CompressionCodec codec = compress.getCodec(); OrcProto.PostScript.Builder builder = OrcProto.PostScript.newBuilder() - .setCompression(writeCompressionKind(codec == null - ? 
CompressionKind.NONE - : codec.getKind())) .setMagic(OrcFile.MAGIC) .addVersion(version.getMajor()) .addVersion(version.getMinor()) .setWriterVersion(writerVersion.getId()); - if (compress.getCodec() != null) { - builder.setCompressionBlockSize(compress.getBufferSize()); + CompressionCodec codec = unencryptedOptions.getCodec(); + if (codec == null) { + builder.setCompression(OrcProto.CompressionKind.NONE); + } else { + builder.setCompression(writeCompressionKind(codec.getKind())) + .setCompressionBlockSize(unencryptedOptions.getBufferSize()); } return physicalWriter.writePostScript(builder); } + private OrcProto.EncryptionKey.Builder writeEncryptionKey(WriterEncryptionKey key) { + OrcProto.EncryptionKey.Builder result = OrcProto.EncryptionKey.newBuilder(); + HadoopShims.KeyMetadata meta = key.getMetadata(); + result.setKeyName(meta.getKeyName()); + result.setKeyVersion(meta.getVersion()); + result.setAlgorithm(OrcProto.EncryptionAlgorithm.valueOf( + meta.getAlgorithm().getSerialization())); + return result; + } + + private OrcProto.EncryptionVariant.Builder + writeEncryptionVariant(WriterEncryptionVariant variant) { + OrcProto.EncryptionVariant.Builder result = + OrcProto.EncryptionVariant.newBuilder(); + result.setRoot(variant.getRoot().getId()); + result.setKey(variant.getKeyDescription().getId()); + result.setEncryptedKey(ByteString.copyFrom(variant.getMaterial().getEncryptedKey())); + return result; + } + + private OrcProto.Encryption.Builder writeEncryptionFooter() { + OrcProto.Encryption.Builder encrypt = OrcProto.Encryption.newBuilder(); + for(MaskDescriptionImpl mask: maskDescriptions) { + OrcProto.DataMask.Builder maskBuilder = OrcProto.DataMask.newBuilder(); + maskBuilder.setName(mask.getName()); + for(String param: mask.getParameters()) { + maskBuilder.addMaskParameters(param); + } + for(TypeDescription column: mask.getColumns()) { + maskBuilder.addColumns(column.getId()); + } + encrypt.addMask(maskBuilder); + } + for(WriterEncryptionKey key: keys.values()) { + encrypt.addKey(writeEncryptionKey(key)); + } + for(WriterEncryptionVariant variant: encryption) { + encrypt.addVariants(writeEncryptionVariant(variant)); + } + encrypt.setKeyProvider(OrcProto.KeyProviderKind.valueOf( + keyProvider.getKind().getValue())); + return encrypt; + } + private long writeFooter() throws IOException { writeMetadata(); OrcProto.Footer.Builder builder = OrcProto.Footer.newBuilder(); @@ -492,12 +617,15 @@ private long writeFooter() throws IOException { builder.addStripes(stripe); } // add the column statistics - writeFileStatistics(builder, treeWriter); + treeWriter.writeFileStatistics(); // add all of the user metadata for(Map.Entry entry: userMetadata.entrySet()) { builder.addMetadata(OrcProto.UserMetadataItem.newBuilder() .setName(entry.getKey()).setValue(entry.getValue())); } + if (encryption.length > 0) { + builder.setEncryption(writeEncryptionFooter()); + } builder.setWriter(OrcFile.WriterImplementation.ORC_JAVA.getId()); physicalWriter.writeFileFooter(builder); return writePostScript(); @@ -515,6 +643,11 @@ public void addUserMetadata(String name, ByteBuffer value) { @Override public void addRowBatch(VectorizedRowBatch batch) throws IOException { + // If this is the first set of rows in this stripe, tell the tree writers + // to prepare the stripe. + if (batch.size != 0 && rowsInStripe == 0) { + treeWriter.prepareStripe(stripes.size() + 1); + } if (buildIndex) { // Batch the writes up to the rowIndexStride so that we can get the // right size indexes. 
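Note: the prepareStripe call added above, and the changeIv overrides threaded through this patch, hinge on CryptoUtils.modifyIvForStripe, which is not part of this diff. Below is a minimal sketch of what such a modifier could look like, assuming a 16-byte AES/CTR IV laid out as 3 bytes of column id, 2 bytes of stream kind, 3 bytes of stripe id, and an 8-byte counter (the layout described by the ORC encryption design). The class name and byte offsets are illustrative, not the shipped implementation:

import java.util.function.Consumer;

class IvSketch {
  // Return a modifier that stamps the 1-based stripe id into bytes 5..7 of
  // the IV and zeroes the trailing counter bytes, so that no (key, IV) pair
  // is ever reused across stripes of the same column and stream kind.
  static Consumer<byte[]> modifyIvForStripe(long stripeId) {
    return iv -> {
      long stripe = stripeId;          // lambda captures must be effectively final
      for (int i = 7; i >= 5; --i) {
        iv[i] = (byte) (stripe & 0xff);
        stripe >>>= 8;
      }
      for (int i = 8; i < iv.length; ++i) {
        iv[i] = 0;                     // restart the CTR counter for the new stripe
      }
    };
  }
}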
@@ -601,6 +734,10 @@ public void appendStripe(byte[] stripe, int offset, int length, checkArgument(stripeStatistics != null, "Stripe statistics must not be null"); + // If we have buffered rows, flush them + if (rowsInStripe > 0) { + flushStripe(); + } rowsInStripe = stripeInfo.getNumberOfRows(); // update stripe information OrcProto.StripeInformation.Builder dirEntry = OrcProto.StripeInformation @@ -614,13 +751,13 @@ public void appendStripe(byte[] stripe, int offset, int length, // since we have already written the stripe, just update stripe statistics treeWriter.updateFileStatistics(stripeStatistics); - fileMetadata.addStripeStats(stripeStatistics); stripes.add(dirEntry.build()); // reset it after writing the stripe rowCount += rowsInStripe; rowsInStripe = 0; + needKeyFlush = encryption.length > 0; } @Override @@ -633,18 +770,17 @@ public void appendUserMetadata(List userMetadata) { } @Override - public ColumnStatistics[] getStatistics() - throws IOException { - // Generate the stats - OrcProto.Footer.Builder builder = OrcProto.Footer.newBuilder(); - - // add the column statistics - writeFileStatistics(builder, treeWriter); - return ReaderImpl.deserializeStats(schema, builder.getStatisticsList()); + public ColumnStatistics[] getStatistics() { + // get the column statistics + final ColumnStatistics[] result = + new ColumnStatistics[schema.getMaximumId() + 1]; + // Get the file statistics, preferring the encrypted one. + treeWriter.getCurrentStatistics(result); + return result; } public CompressionCodec getCompressionCodec() { - return physicalWriter.getCompressionCodec(); + return unencryptedOptions.getCodec(); } private static boolean hasTimestamp(TypeDescription schema) { @@ -661,4 +797,79 @@ private static boolean hasTimestamp(TypeDescription schema) { } return false; } + + WriterEncryptionKey getKey(String keyName, + HadoopShims.KeyProvider provider) throws IOException { + WriterEncryptionKey result = keys.get(keyName); + if (result == null) { + result = new WriterEncryptionKey(provider.getCurrentKeyVersion(keyName)); + keys.put(keyName, result); + } + return result; + } + + MaskDescriptionImpl getMask(OrcFile.EncryptionOption opt) { + MaskDescriptionImpl result = new MaskDescriptionImpl(opt.getMask(), + opt.getMaskParameters()); + // if it is already there, get the earlier object + if (!maskDescriptions.add(result)) { + result = maskDescriptions.tailSet(result).first(); + } + return result; + } + + /** + * Iterate through the encryption options given by the user and set up + * our data structures. + * @param provider the KeyProvider to use to generate keys + * @param options the options from the user + */ + WriterEncryptionVariant[] setupEncryption(HadoopShims.KeyProvider provider, + List options + ) throws IOException { + keyProvider = provider != null ? 
provider : + SHIMS.getKeyProvider(conf, new SecureRandom()); + if (keyProvider == null) { + throw new IllegalArgumentException("Encryption requires a KeyProvider."); + } + // fill out the primary encryption keys + int variantCount = 0; + for(OrcFile.EncryptionOption option: options) { + MaskDescriptionImpl mask = getMask(option); + for(TypeDescription col: schema.findSubtypes(option.getColumnNames())) { + mask.addColumn(col); + } + if (option.getKeyName() != null) { + WriterEncryptionKey key = getKey(option.getKeyName(), keyProvider); + HadoopShims.KeyMetadata metadata = key.getMetadata(); + for(TypeDescription rootType: schema.findSubtypes(option.getColumnNames())) { + WriterEncryptionVariant variant = new WriterEncryptionVariant(key, + rootType, keyProvider.createLocalKey(metadata)); + key.addRoot(variant); + variantCount += 1; + } + } + } + // Now that we have de-duped the keys and maskDescriptions, make the arrays + int nextId = 0; + for (MaskDescriptionImpl mask: maskDescriptions) { + mask.setId(nextId++); + for(TypeDescription column: mask.getColumns()) { + this.columnMaskDescriptions[column.getId()] = mask; + } + } + nextId = 0; + int nextVariantId = 0; + WriterEncryptionVariant[] result = new WriterEncryptionVariant[variantCount]; + for(WriterEncryptionKey key: keys.values()) { + key.setId(nextId++); + key.sortRoots(); + for(WriterEncryptionVariant variant: key.getEncryptionRoots()) { + result[nextVariantId] = variant; + columnEncryption[variant.getRoot().getId()] = variant; + variant.setId(nextVariantId++); + } + } + return result; + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/BinaryTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/BinaryTreeWriter.java index 14669c9ebc..0567d43a5d 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/BinaryTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/BinaryTreeWriter.java @@ -23,27 +23,30 @@ import org.apache.orc.BinaryColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; +import org.apache.orc.impl.StreamName; import java.io.IOException; +import java.util.function.Consumer; public class BinaryTreeWriter extends TreeWriterBase { private final PositionedOutputStream stream; private final IntegerWriter length; private boolean isDirectV2 = true; - public BinaryTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - this.stream = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public BinaryTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + this.stream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.isDirectV2 = isNewWriteFormat(writer); - this.length = createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.LENGTH), false, isDirectV2, writer); + this.length = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.LENGTH, encryption)), + false, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -104,10 +107,8 @@ public void writeBatch(ColumnVector vector, int offset, @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - 
OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -141,4 +142,11 @@ public void flushStreams() throws IOException { } + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + Consumer updater = CryptoUtils.modifyIvForStripe(stripeId); + stream.changeIv(updater); + length.changeIv(updater); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/BooleanTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/BooleanTreeWriter.java index 744aaefa0c..5329cf90aa 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/BooleanTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/BooleanTreeWriter.java @@ -24,21 +24,23 @@ import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; import org.apache.orc.impl.BitFieldWriter; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; +import org.apache.orc.impl.StreamName; import java.io.IOException; +import java.util.function.Consumer; public class BooleanTreeWriter extends TreeWriterBase { private final BitFieldWriter writer; - public BooleanTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - PositionedOutputStream out = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public BooleanTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + PositionedOutputStream out = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.writer = new BitFieldWriter(out, 1); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -70,10 +72,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -101,4 +101,10 @@ public void flushStreams() throws IOException { super.flushStreams(); writer.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + writer.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/ByteTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/ByteTreeWriter.java index a8dc0599bc..a3e1d456df 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/ByteTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/ByteTreeWriter.java @@ -23,21 +23,22 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.RunLengthByteWriter; +import org.apache.orc.impl.StreamName; import java.io.IOException; public class ByteTreeWriter extends 
TreeWriterBase { private final RunLengthByteWriter writer; - public ByteTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - this.writer = new RunLengthByteWriter(writer.createStream(id, - OrcProto.Stream.Kind.DATA)); + public ByteTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + this.writer = new RunLengthByteWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption))); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -80,10 +81,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -111,4 +110,10 @@ public void flushStreams() throws IOException { super.flushStreams(); writer.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + writer.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/CharTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/CharTreeWriter.java index 14e3c26f22..83a72e9298 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/CharTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/CharTreeWriter.java @@ -34,11 +34,10 @@ public class CharTreeWriter extends StringBaseTreeWriter { private final int maxLength; private final byte[] padding; - CharTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + CharTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); maxLength = schema.getMaxLength(); // utf-8 is currently 4 bytes long, but it could be upto 6 padding = new byte[6*maxLength]; diff --git a/java/core/src/java/org/apache/orc/impl/writer/DateTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/DateTreeWriter.java index 209dd0e36b..bc81d456a0 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/DateTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/DateTreeWriter.java @@ -23,9 +23,11 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -33,13 +35,12 @@ public class DateTreeWriter extends TreeWriterBase { private final IntegerWriter writer; private final boolean isDirectV2; - public DateTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - OutStream out = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public DateTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + 
WriterContext writer) throws IOException { + super(schema, encryption, writer); + OutStream out = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.isDirectV2 = isNewWriteFormat(writer); this.writer = createIntegerWriter(out, true, isDirectV2, writer); if (rowIndexPosition != null) { @@ -84,10 +85,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -127,4 +126,9 @@ public void flushStreams() throws IOException { writer.flush(); } + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + writer.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/Decimal64TreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/Decimal64TreeWriter.java index 020d8ff4f8..4b3cfdd63f 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/Decimal64TreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/Decimal64TreeWriter.java @@ -25,11 +25,14 @@ import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.RunLengthIntegerWriterV2; +import org.apache.orc.impl.StreamName; import java.io.IOException; +import java.util.function.Consumer; /** * Writer for short decimals in ORCv2. @@ -38,12 +41,12 @@ public class Decimal64TreeWriter extends TreeWriterBase { private final RunLengthIntegerWriterV2 valueWriter; private final int scale; - public Decimal64TreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - OutStream stream = writer.createStream(id, OrcProto.Stream.Kind.DATA); + public Decimal64TreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + OutStream stream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); // Use RLEv2 until we have the new RLEv3.
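// The "true, true" arguments on the next line are RunLengthIntegerWriterV2's
// "signed" and "alignedBitpacking" flags: short-decimal values are signed, and
// aligned bit-packing rounds each run's bit width up to an aligned size,
// trading a little space for faster decoding.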
valueWriter = new RunLengthIntegerWriterV2(stream, true, true); scale = schema.getScale(); @@ -121,10 +124,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -151,4 +152,10 @@ public void flushStreams() throws IOException { super.flushStreams(); valueWriter.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + valueWriter.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/DecimalTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/DecimalTreeWriter.java index 822042eef5..be2b2bf3bc 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/DecimalTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/DecimalTreeWriter.java @@ -26,12 +26,15 @@ import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; import org.apache.orc.impl.SerializationUtils; +import org.apache.orc.impl.StreamName; import java.io.IOException; +import java.util.function.Consumer; public class DecimalTreeWriter extends TreeWriterBase { private final PositionedOutputStream valueStream; @@ -44,17 +47,18 @@ public class DecimalTreeWriter extends TreeWriterBase { private final IntegerWriter scaleStream; private final boolean isDirectV2; - public DecimalTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + public DecimalTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); - valueStream = writer.createStream(id, OrcProto.Stream.Kind.DATA); + valueStream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); scratchLongs = new long[HiveDecimal.SCRATCH_LONGS_LEN]; scratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES]; - this.scaleStream = createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.SECONDARY), true, isDirectV2, writer); + this.scaleStream = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.SECONDARY, encryption)), + true, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -161,10 +165,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -195,4 +197,12 @@ public void flushStreams() throws IOException { valueStream.flush(); 
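// DecimalTreeWriter owns two physical streams: DATA carries the serialized
// unbounded-precision values (flushed above) and SECONDARY carries the
// per-value scales (flushed below); both must be flushed here.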
scaleStream.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + Consumer<byte[]> updater = CryptoUtils.modifyIvForStripe(stripeId); + valueStream.changeIv(updater); + scaleStream.changeIv(updater); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/DoubleTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/DoubleTreeWriter.java index 84218ca6fb..17f0f73317 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/DoubleTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/DoubleTreeWriter.java @@ -23,9 +23,11 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; import org.apache.orc.impl.SerializationUtils; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -33,13 +35,12 @@ public class DoubleTreeWriter extends TreeWriterBase { private final PositionedOutputStream stream; private final SerializationUtils utils; - public DoubleTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - this.stream = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public DoubleTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + this.stream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.utils = new SerializationUtils(); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -83,10 +84,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); stream.flush(); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -115,4 +114,10 @@ public void flushStreams() throws IOException { super.flushStreams(); stream.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + stream.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/EncryptionTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/EncryptionTreeWriter.java new file mode 100644 index 0000000000..981f75a9c5 --- /dev/null +++ b/java/core/src/java/org/apache/orc/impl/writer/EncryptionTreeWriter.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.orc.impl.writer; + +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; +import org.apache.orc.DataMask; +import org.apache.orc.OrcProto; +import org.apache.orc.TypeDescription; + +import java.io.IOException; + +/** + * TreeWriter that handles column encryption. + * We create a TreeWriter for each of the alternatives with a WriterContext + * that creates encrypted streams. + */ +public class EncryptionTreeWriter implements TreeWriter { + // the different writers + private final TreeWriter[] childrenWriters; + private final DataMask[] masks; + // a column vector that we use to apply the masks + private final VectorizedRowBatch scratch; + + EncryptionTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext context) throws IOException { + scratch = schema.createRowBatch(); + childrenWriters = new TreeWriterBase[2]; + masks = new DataMask[childrenWriters.length]; + + // no mask, encrypted data + masks[0] = null; + childrenWriters[0] = Factory.createSubtree(schema, encryption, context); + + // masked unencrypted + masks[1] = context.getUnencryptedMask(schema.getId()); + childrenWriters[1] = Factory.createSubtree(schema, null, context); + } + + @Override + public void writeRootBatch(VectorizedRowBatch batch, int offset, + int length) throws IOException { + scratch.ensureSize(length); + for(int alt=0; alt < childrenWriters.length; ++alt) { + // if there is a mask, apply it to each column + if (masks[alt] != null) { + for(int col=0; col < scratch.cols.length; ++col) { + masks[alt].maskData(batch.cols[col], scratch.cols[col], offset, + length); + } + childrenWriters[alt].writeRootBatch(scratch, offset, length); + } else { + childrenWriters[alt].writeRootBatch(batch, offset, length); + } + } + } + + @Override + public void writeBatch(ColumnVector vector, int offset, + int length) throws IOException { + for(int alt=0; alt < childrenWriters.length; ++alt) { + // if there is a mask, apply it to the column + if (masks[alt] != null) { + masks[alt].maskData(vector, scratch.cols[0], offset, length); + childrenWriters[alt].writeBatch(scratch.cols[0], offset, length); + } else { + childrenWriters[alt].writeBatch(vector, offset, length); + } + } + } + + @Override + public void createRowIndexEntry() throws IOException { + for(TreeWriter child: childrenWriters) { + child.createRowIndexEntry(); + } + } + + @Override + public void flushStreams() throws IOException { + for(TreeWriter child: childrenWriters) { + child.flushStreams(); + } + } + + @Override + public void writeStripe(int requiredIndexEntries) throws IOException { + for(TreeWriter child: childrenWriters) { + child.writeStripe(requiredIndexEntries); + } + } + + @Override + public void updateFileStatistics(OrcProto.StripeStatistics stats) { + for(TreeWriter child: childrenWriters) { + child.updateFileStatistics(stats); + } + } + + @Override + public long estimateMemory() { + long result = 0; + for (TreeWriter writer : childrenWriters) { + result += writer.estimateMemory(); + } + return result; + } + + @Override + public long getRawDataSize() { + // return the size of the encrypted data + return childrenWriters[0].getRawDataSize(); + } + + @Override + public void prepareStripe(int stripeId) { + for (TreeWriter writer : childrenWriters) { + writer.prepareStripe(stripeId); + } + }
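For context on the prepareStripe chain used throughout these writers: CryptoUtils.modifyIvForStripe(stripeId) returns a Consumer<byte[]> that each writer hands to its streams via changeIv, so every stream's initialization vector is rebased when a new stripe starts. A minimal sketch of what such an updater could look like, assuming a 16-byte AES/CTR IV that carries the stripe id just ahead of the block counter (the byte offsets and helper name are illustrative assumptions; the authoritative layout lives in CryptoUtils):

import java.util.function.Consumer;

class IvSketch {
  // Assumed layout: bytes 5..7 hold the stripe id, bytes 8..15 the CTR counter.
  static Consumer<byte[]> modifyIvForStripeSketch(int stripeId) {
    return iv -> {
      iv[5] = (byte) (stripeId >>> 16);  // stamp the stripe id, big endian
      iv[6] = (byte) (stripeId >>> 8);
      iv[7] = (byte) stripeId;
      for (int i = 8; i < iv.length; ++i) {
        iv[i] = 0;                       // restart the block counter at zero
      }
    };
  }
}

Resetting the counter is only safe because the stripe id bytes keep the full IV unique per stream and stripe; CTR mode must never reuse an IV position under the same key.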
+ + @Override + public void writeFileStatistics() throws IOException { + for (TreeWriter child : childrenWriters) { + child.writeFileStatistics(); + } + } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + childrenWriters[0].getCurrentStatistics(output); + } +} diff --git a/java/core/src/java/org/apache/orc/impl/writer/FloatTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/FloatTreeWriter.java index e4198a21ac..bc3a15b023 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/FloatTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/FloatTreeWriter.java @@ -23,9 +23,11 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; import org.apache.orc.impl.SerializationUtils; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -33,13 +35,12 @@ public class FloatTreeWriter extends TreeWriterBase { private final PositionedOutputStream stream; private final SerializationUtils utils; - public FloatTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - this.stream = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public FloatTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + this.stream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.utils = new SerializationUtils(); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -84,10 +85,8 @@ public void writeBatch(ColumnVector vector, int offset, @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -115,4 +114,10 @@ public void flushStreams() throws IOException { super.flushStreams(); stream.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + stream.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/IntegerTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/IntegerTreeWriter.java index dc0eaad1b9..7f8f21a0bd 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/IntegerTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/IntegerTreeWriter.java @@ -23,9 +23,11 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -34,13 +36,12 @@ public class IntegerTreeWriter extends TreeWriterBase { private boolean isDirectV2 = true; private final boolean isLong; - public IntegerTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, 
schema, writer, nullable); - OutStream out = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public IntegerTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + OutStream out = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.isDirectV2 = isNewWriteFormat(writer); this.writer = createIntegerWriter(out, true, isDirectV2, writer); if (rowIndexPosition != null) { @@ -97,10 +98,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -129,4 +128,10 @@ public void flushStreams() throws IOException { super.flushStreams(); writer.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + writer.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java index c6068cdec0..3cd3ed11b1 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java @@ -20,10 +20,13 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -32,15 +35,15 @@ public class ListTreeWriter extends TreeWriterBase { private final boolean isDirectV2; private final TreeWriter childWriter; - ListTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + ListTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); - childWriter = Factory.create(schema.getChildren().get(0), writer, true); - lengths = createIntegerWriter(writer.createStream(columnId, - OrcProto.Stream.Kind.LENGTH), false, isDirectV2, writer); + childWriter = Factory.create(schema.getChildren().get(0), encryption, writer); + lengths = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.LENGTH, encryption)), + false, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -119,11 +122,9 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); - childWriter.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); + 
childWriter.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -153,9 +154,9 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); - childWriter.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); + childWriter.writeFileStatistics(); } @Override @@ -164,4 +165,17 @@ public void flushStreams() throws IOException { lengths.flush(); childWriter.flushStreams(); } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + childWriter.getCurrentStatistics(output); + } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + lengths.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + childWriter.prepareStripe(stripeId); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java index 91e56578ae..02191adc7b 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java @@ -19,10 +19,13 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; +import org.apache.orc.impl.StreamName; import java.io.IOException; import java.util.List; @@ -33,17 +36,17 @@ public class MapTreeWriter extends TreeWriterBase { private final TreeWriter keyWriter; private final TreeWriter valueWriter; - MapTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + MapTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); List<TypeDescription> children = schema.getChildren(); - keyWriter = Factory.create(children.get(0), writer, true); - valueWriter = Factory.create(children.get(1), writer, true); - lengths = createIntegerWriter(writer.createStream(columnId, - OrcProto.Stream.Kind.LENGTH), false, isDirectV2, writer); + keyWriter = Factory.create(children.get(0), encryption, writer); + valueWriter = Factory.create(children.get(1), encryption, writer); + lengths = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.LENGTH, encryption)), + false, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -128,12 +131,10 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); - keyWriter.writeStripe(builder, stats, requiredIndexEntries); - valueWriter.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); + keyWriter.writeStripe(requiredIndexEntries); + valueWriter.writeStripe(requiredIndexEntries); if (rowIndexPosition
!= null) { recordPosition(rowIndexPosition); } @@ -164,10 +165,10 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); - keyWriter.writeFileStatistics(footer); - valueWriter.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); + keyWriter.writeFileStatistics(); + valueWriter.writeFileStatistics(); } @Override @@ -177,4 +178,19 @@ public void flushStreams() throws IOException { keyWriter.flushStreams(); valueWriter.flushStreams(); } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + keyWriter.getCurrentStatistics(output); + valueWriter.getCurrentStatistics(output); + } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + lengths.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + keyWriter.prepareStripe(stripeId); + valueWriter.prepareStripe(stripeId); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/StringBaseTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/StringBaseTreeWriter.java index e7d32593d3..c3f56a6d03 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/StringBaseTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/StringBaseTreeWriter.java @@ -25,16 +25,19 @@ import org.apache.orc.OrcProto; import org.apache.orc.StringColumnStatistics; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.DynamicIntArray; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; +import org.apache.orc.impl.StreamName; import org.apache.orc.impl.StringRedBlackTree; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; public abstract class StringBaseTreeWriter extends TreeWriterBase { private static final int INITIAL_DICTIONARY_SIZE = 4096; @@ -57,17 +60,18 @@ public abstract class StringBaseTreeWriter extends TreeWriterBase { private boolean doneDictionaryCheck; private final boolean strideDictionaryCheck; - StringBaseTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + StringBaseTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); - directStreamOutput = writer.createStream(id, OrcProto.Stream.Kind.DATA); - stringOutput = writer.createStream(id, - OrcProto.Stream.Kind.DICTIONARY_DATA); - lengthOutput = createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.LENGTH), false, isDirectV2, writer); + directStreamOutput = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); + stringOutput = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DICTIONARY_DATA, encryption)); + lengthOutput = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.LENGTH, encryption)), + false, isDirectV2, writer); rowOutput = createIntegerWriter(directStreamOutput, false, isDirectV2, writer); if (rowIndexPosition != null) { @@ -76,7 +80,7 @@ public abstract class StringBaseTreeWriter extends TreeWriterBase { rowIndexValueCount.add(0L); buildIndex = 
writer.buildIndex(); Configuration conf = writer.getConfiguration(); - dictionaryKeySizeThreshold = writer.getDictionaryKeySizeThreshold(columnId); + dictionaryKeySizeThreshold = writer.getDictionaryKeySizeThreshold(id); strideDictionaryCheck = OrcConf.ROW_INDEX_STRIDE_DICTIONARY_CHECK.getBoolean(conf); if (dictionaryKeySizeThreshold <= 0.0) { @@ -99,9 +103,10 @@ private void checkDictionaryEncoding() { } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { + public void writeStripe(int requiredIndexEntries) throws IOException { + // if the stripe has fewer rows than dictionaryCheckAfterRows, the dictionary + // check would not have happened. So do it again here. + checkDictionaryEncoding(); checkDictionaryEncoding(); if (!useDictionaryEncoding) { @@ -110,8 +115,7 @@ public void writeStripe(OrcProto.StripeFooter.Builder builder, // we need to build the rowindex before calling super, since it // writes it out. - super.writeStripe(builder, stats, requiredIndexEntries); - + super.writeStripe(requiredIndexEntries); // reset all of the fields to be ready for the next stripe. dictionary.clear(); savedRowIndex.clear(); @@ -297,7 +301,16 @@ public void flushStreams() throws IOException { directStreamOutput.flush(); lengthOutput.flush(); } + } + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + Consumer<byte[]> updater = CryptoUtils.modifyIvForStripe(stripeId); + stringOutput.changeIv(updater); + lengthOutput.changeIv(updater); + rowOutput.changeIv(updater); + directStreamOutput.changeIv(updater); } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/StringTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/StringTreeWriter.java index ab6f38f9c6..ed1de950fd 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/StringTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/StringTreeWriter.java @@ -26,11 +26,10 @@ import java.nio.charset.StandardCharsets; public class StringTreeWriter extends StringBaseTreeWriter { - StringTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + StringTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); } @Override diff --git a/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java index ee0b0c041a..a78b387533 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java @@ -21,6 +21,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; @@ -30,15 +31,14 @@ public class StructTreeWriter extends TreeWriterBase { final TreeWriter[] childrenWriters; - public StructTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + public StructTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema,
encryption, writer); List<TypeDescription> children = schema.getChildren(); - childrenWriters = new TreeWriterBase[children.size()]; + childrenWriters = new TreeWriter[children.size()]; for (int i = 0; i < childrenWriters.length; ++i) { - childrenWriters[i] = Factory.create(children.get(i), writer, true); + childrenWriters[i] = Factory.create(children.get(i), encryption, writer); } if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -108,12 +108,10 @@ public void createRowIndexEntry() throws IOException { } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); for (TreeWriter child : childrenWriters) { - child.writeStripe(builder, stats, requiredIndexEntries); + child.writeStripe(requiredIndexEntries); } if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -147,10 +145,10 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); for (TreeWriter child : childrenWriters) { - child.writeFileStatistics(footer); + child.writeFileStatistics(); } } @@ -160,6 +158,21 @@ public void flushStreams() throws IOException { for (TreeWriter child : childrenWriters) { child.flushStreams(); } + } + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + for (TreeWriter child: childrenWriters) { + child.getCurrentStatistics(output); + } + } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + for (TreeWriter child: childrenWriters) { + child.prepareStripe(stripeId); + } } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/TimestampTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/TimestampTreeWriter.java index 0f30d07757..3ba2dbeb9c 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TimestampTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TimestampTreeWriter.java @@ -23,15 +23,18 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.SerializationUtils; +import org.apache.orc.impl.StreamName; import java.io.IOException; import java.text.DateFormat; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.TimeZone; +import java.util.function.Consumer; public class TimestampTreeWriter extends TreeWriterBase { public static final int MILLIS_PER_SECOND = 1000; @@ -45,16 +48,17 @@ public class TimestampTreeWriter extends TreeWriterBase { private final long baseEpochSecsLocalTz; private final long baseEpochSecsUTC; - public TimestampTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + public TimestampTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); - this.seconds =
createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.DATA), true, isDirectV2, writer); - this.nanos = createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.SECONDARY), false, isDirectV2, writer); + this.seconds = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)), + true, isDirectV2, writer); + this.nanos = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.SECONDARY, encryption)), + false, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -152,10 +156,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -202,4 +204,12 @@ public void flushStreams() throws IOException { seconds.flush(); nanos.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + Consumer<byte[]> updater = CryptoUtils.modifyIvForStripe(stripeId); + seconds.changeIv(updater); + nanos.changeIv(updater); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java index bfa403eeff..680cf8cebf 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java @@ -20,6 +20,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; @@ -45,6 +46,12 @@ public interface TreeWriter { */ long getRawDataSize(); + /** + * Set up for the next stripe. + * @param stripeId the next stripe id + */ + void prepareStripe(int stripeId); + /** * Write a VectorizedRowBatch to the file. This is called by the WriterImplV2 * at the top level. @@ -78,17 +85,11 @@ void writeBatch(ColumnVector vector, int offset, /** * Write the stripe out to the file. - * @param stripeFooter the stripe footer that contains the information about the - * layout of the stripe. The TreeWriterBase is required to update - * the footer with its information. - * @param stats the stripe statistics information * @param requiredIndexEntries the number of index entries that are * required. this is to check to make sure the * row index is well formed. */ - void writeStripe(OrcProto.StripeFooter.Builder stripeFooter, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException; + void writeStripe(int requiredIndexEntries) throws IOException; /** * During a stripe append, we need to update the file statistics. @@ -97,77 +98,93 @@ void writeStripe(OrcProto.StripeFooter.Builder stripeFooter, void updateFileStatistics(OrcProto.StripeStatistics stripeStatistics); /** - * Add the file statistics to the file footer. - * @param footer the file footer builder + * Write the FileStatistics for each column in each encryption variant. + */ + void writeFileStatistics() throws IOException; + + /** + * Get the current file statistics for each column.
If a column is encrypted, + * the encrypted variant statistics are used. + * @param output an array that is filled in with the results */ - void writeFileStatistics(OrcProto.Footer.Builder footer); + void getCurrentStatistics(ColumnStatistics[] output); class Factory { + /** + * Create a new tree writer for the given types and insert encryption if + * required. + * @param schema the type to build a writer for + * @param encryption the encryption status + * @param streamFactory the writer context + * @return a new tree writer + */ public static TreeWriter create(TypeDescription schema, - WriterContext streamFactory, - boolean nullable) throws IOException { + WriterEncryptionVariant encryption, + WriterContext streamFactory) throws IOException { + if (encryption == null) { + // If we are the root of an encryption variant, create a special writer. + encryption = streamFactory.getEncryption(schema.getId()); + if (encryption != null) { + return new EncryptionTreeWriter(schema, encryption, streamFactory); + } + } + return createSubtree(schema, encryption, streamFactory); + } + + /** + * Create a subtree without inserting encryption nodes + * @param schema the schema to create + * @param encryption the encryption variant + * @param streamFactory the writer context + * @return a new tree writer + */ + static TreeWriter createSubtree(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext streamFactory) throws IOException { OrcFile.Version version = streamFactory.getVersion(); switch (schema.getCategory()) { - case BOOLEAN: - return new BooleanTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case BYTE: - return new ByteTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case SHORT: - case INT: - case LONG: - return new IntegerTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case FLOAT: - return new FloatTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case DOUBLE: - return new DoubleTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case STRING: - return new StringTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case CHAR: - return new CharTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case VARCHAR: - return new VarcharTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case BINARY: - return new BinaryTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case TIMESTAMP: - return new TimestampTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case DATE: - return new DateTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case DECIMAL: - if (version == OrcFile.Version.UNSTABLE_PRE_2_0 && - schema.getPrecision() <= TypeDescription.MAX_DECIMAL64_PRECISION) { - return new Decimal64TreeWriter(schema.getId(), - schema, streamFactory, nullable); - } - return new DecimalTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case STRUCT: - return new StructTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case MAP: - return new MapTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case LIST: - return new ListTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case UNION: - return new UnionTreeWriter(schema.getId(), - schema, streamFactory, nullable); - default: - throw new IllegalArgumentException("Bad category: " + - schema.getCategory()); + case BOOLEAN: + return new BooleanTreeWriter(schema, encryption, streamFactory); + case BYTE: + return new ByteTreeWriter(schema, 
encryption, streamFactory); + case SHORT: + case INT: + case LONG: + return new IntegerTreeWriter(schema, encryption, streamFactory); + case FLOAT: + return new FloatTreeWriter(schema, encryption, streamFactory); + case DOUBLE: + return new DoubleTreeWriter(schema, encryption, streamFactory); + case STRING: + return new StringTreeWriter(schema, encryption, streamFactory); + case CHAR: + return new CharTreeWriter(schema, encryption, streamFactory); + case VARCHAR: + return new VarcharTreeWriter(schema, encryption, streamFactory); + case BINARY: + return new BinaryTreeWriter(schema, encryption, streamFactory); + case TIMESTAMP: + return new TimestampTreeWriter(schema, encryption, streamFactory); + case DATE: + return new DateTreeWriter(schema, encryption, streamFactory); + case DECIMAL: + if (version == OrcFile.Version.UNSTABLE_PRE_2_0 && + schema.getPrecision() <= TypeDescription.MAX_DECIMAL64_PRECISION) { + return new Decimal64TreeWriter(schema, encryption, streamFactory); + } + return new DecimalTreeWriter(schema, encryption, streamFactory); + case STRUCT: + return new StructTreeWriter(schema, encryption, streamFactory); + case MAP: + return new MapTreeWriter(schema, encryption, streamFactory); + case LIST: + return new ListTreeWriter(schema, encryption, streamFactory); + case UNION: + return new UnionTreeWriter(schema, encryption, streamFactory); + default: + throw new IllegalArgumentException("Bad category: " + + schema.getCategory()); } } - } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java b/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java index 7934b21755..17a2a5fbb4 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java @@ -23,11 +23,13 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; import org.apache.orc.impl.BitFieldWriter; import org.apache.orc.impl.ColumnStatisticsImpl; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; @@ -48,6 +50,8 @@ public abstract class TreeWriterBase implements TreeWriter { protected final int id; protected final BitFieldWriter isPresent; + protected final TypeDescription schema; + protected final WriterEncryptionVariant encryption; private final boolean isCompressed; protected final ColumnStatisticsImpl indexStatistics; protected final ColumnStatisticsImpl stripeColStatistics; @@ -63,37 +67,31 @@ public abstract class TreeWriterBase implements TreeWriter { protected final OrcProto.BloomFilter.Builder bloomFilterEntry; private boolean foundNulls; private OutStream isPresentOutStream; - private final WriterContext streamFactory; - private final TypeDescription schema; + protected final WriterContext context; /** * Create a tree writer. - * @param columnId the column id of the column to write * @param schema the row schema - * @param streamFactory limited access to the Writer's data. - * @param nullable can the value be null? + * @param encryption the encryption variant or null if it is unencrypted + * @param context limited access to the Writer's data. 
*/ - TreeWriterBase(int columnId, - TypeDescription schema, - WriterContext streamFactory, - boolean nullable) throws IOException { + TreeWriterBase(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext context) throws IOException { this.schema = schema; - this.streamFactory = streamFactory; - this.isCompressed = streamFactory.isCompressed(); - this.id = columnId; - if (nullable) { - isPresentOutStream = streamFactory.createStream(id, - OrcProto.Stream.Kind.PRESENT); - isPresent = new BitFieldWriter(isPresentOutStream, 1); - } else { - isPresent = null; - } + this.encryption = encryption; + this.context = context; + this.isCompressed = context.isCompressed(); + this.id = schema.getId(); + isPresentOutStream = context.createStream(new StreamName(id, + OrcProto.Stream.Kind.PRESENT, encryption)); + isPresent = new BitFieldWriter(isPresentOutStream, 1); this.foundNulls = false; - createBloomFilter = streamFactory.getBloomFilterColumns()[columnId]; + createBloomFilter = context.getBloomFilterColumns()[id]; indexStatistics = ColumnStatisticsImpl.create(schema); stripeColStatistics = ColumnStatisticsImpl.create(schema); fileStatistics = ColumnStatisticsImpl.create(schema); - if (streamFactory.buildIndex()) { + if (context.buildIndex()) { rowIndex = OrcProto.RowIndex.newBuilder(); rowIndexEntry = OrcProto.RowIndexEntry.newBuilder(); rowIndexPosition = new RowIndexPositionRecorder(rowIndexEntry); @@ -104,16 +102,16 @@ public abstract class TreeWriterBase implements TreeWriter { } if (createBloomFilter) { bloomFilterEntry = OrcProto.BloomFilter.newBuilder(); - if (streamFactory.getBloomFilterVersion() == OrcFile.BloomFilterVersion.ORIGINAL) { - bloomFilter = new BloomFilter(streamFactory.getRowIndexStride(), - streamFactory.getBloomFilterFPP()); + if (context.getBloomFilterVersion() == OrcFile.BloomFilterVersion.ORIGINAL) { + bloomFilter = new BloomFilter(context.getRowIndexStride(), + context.getBloomFilterFPP()); bloomFilterIndex = OrcProto.BloomFilterIndex.newBuilder(); } else { bloomFilter = null; bloomFilterIndex = null; } - bloomFilterUtf8 = new BloomFilterUtf8(streamFactory.getRowIndexStride(), - streamFactory.getBloomFilterFPP()); + bloomFilterUtf8 = new BloomFilterUtf8(context.getRowIndexStride(), + context.getBloomFilterFPP()); bloomFilterIndexUtf8 = OrcProto.BloomFilterIndex.newBuilder(); } else { bloomFilterEntry = null; @@ -232,17 +230,21 @@ private void removeIsPresentPositions() { } @Override - public void flushStreams() throws IOException { + public void prepareStripe(int stripeId) { + if (isPresent != null) { + isPresent.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } + } + @Override + public void flushStreams() throws IOException { if (isPresent != null) { isPresent.flush(); } - } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, int requiredIndexEntries) throws IOException { + public void writeStripe(int requiredIndexEntries) throws IOException { // if no nulls are found in a stream, then suppress the stream if (isPresent != null && !foundNulls) { @@ -252,47 +254,47 @@ public void writeStripe(OrcProto.StripeFooter.Builder builder, if (rowIndex != null) { removeIsPresentPositions(); } - } /* Update byte count */ - final long byteCount = streamFactory.getPhysicalWriter().getFileBytes(id); + final long byteCount = context.getPhysicalWriter().getFileBytes(id, encryption); stripeColStatistics.updateByteCount(byteCount); // merge stripe-level column statistics to file statistics and write 
it to // stripe statistics fileStatistics.merge(stripeColStatistics); - stats.addColStats(stripeColStatistics.serialize()); + context.writeStatistics( + new StreamName(id, OrcProto.Stream.Kind.STRIPE_STATISTICS, encryption), + stripeColStatistics.serialize()); stripeColStatistics.reset(); // reset the flag for next stripe foundNulls = false; - builder.addColumns(getEncoding()); + context.setEncoding(id, encryption, getEncoding().build()); if (rowIndex != null) { if (rowIndex.getEntryCount() != requiredIndexEntries) { throw new IllegalArgumentException("Column has wrong number of " + "index entries found: " + rowIndex.getEntryCount() + " expected: " + requiredIndexEntries); } - streamFactory.writeIndex(new StreamName(id, OrcProto.Stream.Kind.ROW_INDEX), rowIndex); + context.writeIndex(new StreamName(id, OrcProto.Stream.Kind.ROW_INDEX), rowIndex); rowIndex.clear(); rowIndexEntry.clear(); } // write the bloom filter to out stream if (bloomFilterIndex != null) { - streamFactory.writeBloomFilter(new StreamName(id, + context.writeBloomFilter(new StreamName(id, OrcProto.Stream.Kind.BLOOM_FILTER), bloomFilterIndex); bloomFilterIndex.clear(); } // write the bloom filter to out stream if (bloomFilterIndexUtf8 != null) { - streamFactory.writeBloomFilter(new StreamName(id, + context.writeBloomFilter(new StreamName(id, OrcProto.Stream.Kind.BLOOM_FILTER_UTF8), bloomFilterIndexUtf8); bloomFilterIndexUtf8.clear(); } - } /** @@ -369,8 +371,10 @@ public long estimateMemory() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - footer.addStatistics(fileStatistics.serialize()); + public void writeFileStatistics() throws IOException { + context.writeStatistics(new StreamName(id, + OrcProto.Stream.Kind.FILE_STATISTICS, encryption), + fileStatistics.serialize()); } static class RowIndexPositionRecorder implements PositionRecorder { @@ -385,4 +389,9 @@ public void addPosition(long position) { builder.addPositions(position); } } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + output[id] = fileStatistics; + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java index 54a9a3a6dc..df4dfef123 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java @@ -20,10 +20,13 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.RunLengthByteWriter; +import org.apache.orc.impl.StreamName; import java.io.IOException; import java.util.List; @@ -32,19 +35,18 @@ public class UnionTreeWriter extends TreeWriterBase { private final RunLengthByteWriter tags; private final TreeWriter[] childrenWriters; - UnionTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + UnionTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); List<TypeDescription> children = schema.getChildren(); childrenWriters = new TreeWriterBase[children.size()]; for (int i = 0; i < childrenWriters.length; ++i) { - childrenWriters[i] =
Factory.create(children.get(i), writer, true); + childrenWriters[i] = Factory.create(children.get(i), encryption, writer); } tags = - new RunLengthByteWriter(writer.createStream(columnId, - OrcProto.Stream.Kind.DATA)); + new RunLengthByteWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption))); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -120,12 +122,10 @@ public void createRowIndexEntry() throws IOException { } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); for (TreeWriter child : childrenWriters) { - child.writeStripe(builder, stats, requiredIndexEntries); + child.writeStripe(requiredIndexEntries); } if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -165,10 +165,10 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); for (TreeWriter child : childrenWriters) { - child.writeFileStatistics(footer); + child.writeFileStatistics(); } } @@ -180,4 +180,21 @@ public void flushStreams() throws IOException { child.flushStreams(); } } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + for(TreeWriter child: childrenWriters) { + child.getCurrentStatistics(output); + } + } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + tags.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + for (TreeWriter child: childrenWriters) { + child.prepareStripe(stripeId); + } + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/VarcharTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/VarcharTreeWriter.java index b08ef437cf..29a6ab75d4 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/VarcharTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/VarcharTreeWriter.java @@ -32,11 +32,10 @@ public class VarcharTreeWriter extends StringBaseTreeWriter { private final int maxLength; - VarcharTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + VarcharTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); maxLength = schema.getMaxLength(); } diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java b/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java index 9ef3ddaf2e..73542ad976 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java +++ b/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java @@ -19,9 +19,12 @@ package org.apache.orc.impl.writer; import org.apache.hadoop.conf.Configuration; +import org.apache.orc.CompressionCodec; +import org.apache.orc.DataMask; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.PhysicalWriter; +import org.apache.orc.TypeDescription; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.StreamName; @@ -30,79 +33,115 @@ public interface WriterContext { /** - * Create a stream 
to store part of a column. - * @param column the column id for the stream - * @param kind the kind of stream - * @return The output outStream that the section needs to be written to. - */ - OutStream createStream(int column, - OrcProto.Stream.Kind kind - ) throws IOException; - - /** - * Get the stride rate of the row index. - */ - int getRowIndexStride(); - - /** - * Should be building the row index. - * @return true if we are building the index - */ - boolean buildIndex(); - - /** - * Is the ORC file compressed? - * @return are the streams compressed - */ - boolean isCompressed(); - - /** - * Get the encoding strategy to use. - * @return encoding strategy - */ - OrcFile.EncodingStrategy getEncodingStrategy(); - - /** - * Get the bloom filter columns - * @return bloom filter columns - */ - boolean[] getBloomFilterColumns(); - - /** - * Get bloom filter false positive percentage. - * @return fpp - */ - double getBloomFilterFPP(); - - /** - * Get the writer's configuration. - * @return configuration - */ - Configuration getConfiguration(); - - /** - * Get the version of the file to write. - */ - OrcFile.Version getVersion(); - - /** - * Get the PhysicalWriter. - * - * @return the file's physical writer. - */ - PhysicalWriter getPhysicalWriter(); - - - OrcFile.BloomFilterVersion getBloomFilterVersion(); - - void writeIndex(StreamName name, - OrcProto.RowIndex.Builder index) throws IOException; - - void writeBloomFilter(StreamName name, - OrcProto.BloomFilterIndex.Builder bloom - ) throws IOException; - - boolean getUseUTCTimestamp(); - - double getDictionaryKeySizeThreshold(int column); + * Create a stream to store part of a column. + * @param name the name of the stream + * @return The output outStream that the section needs to be written to. + */ + OutStream createStream(StreamName name) throws IOException; + + /** + * Get the stride rate of the row index. + */ + int getRowIndexStride(); + + /** + * Should be building the row index. + * @return true if we are building the index + */ + boolean buildIndex(); + + /** + * Is the ORC file compressed? + * @return are the streams compressed + */ + boolean isCompressed(); + + /** + * Get the encoding strategy to use. + * @return encoding strategy + */ + OrcFile.EncodingStrategy getEncodingStrategy(); + + /** + * Get the bloom filter columns + * @return bloom filter columns + */ + boolean[] getBloomFilterColumns(); + + /** + * Get bloom filter false positive percentage. + * @return fpp + */ + double getBloomFilterFPP(); + + /** + * Get the writer's configuration. + * @return configuration + */ + Configuration getConfiguration(); + + /** + * Get the version of the file to write. + */ + OrcFile.Version getVersion(); + + OrcFile.BloomFilterVersion getBloomFilterVersion(); + + void writeIndex(StreamName name, + OrcProto.RowIndex.Builder index) throws IOException; + + void writeBloomFilter(StreamName name, + OrcProto.BloomFilterIndex.Builder bloom + ) throws IOException; + + /** + * Get the mask for the unencrypted variant. + * @param columnId the column id + * @return the mask to apply to the unencrypted data or null if there is none + */ + DataMask getUnencryptedMask(int columnId); + + /** + * Get the encryption for the given column. + * @param columnId the root column id + * @return the column encryption or null if it isn't encrypted + */ + WriterEncryptionVariant getEncryption(int columnId); + + /** + * Get the PhysicalWriter. + * @return the file's physical writer. 
diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java
new file mode 100644
index 0000000000..40606aa98c
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl.writer;
+
+import org.apache.orc.EncryptionAlgorithm;
+import org.apache.orc.EncryptionKey;
+import org.apache.orc.EncryptionVariant;
+import org.apache.orc.impl.HadoopShims;
+import org.jetbrains.annotations.NotNull;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+public class WriterEncryptionKey implements EncryptionKey {
+  private final HadoopShims.KeyMetadata metadata;
+  private final List<WriterEncryptionVariant> roots = new ArrayList<>();
+  private int id;
+
+  public WriterEncryptionKey(HadoopShims.KeyMetadata key) {
+    this.metadata = key;
+  }
+
+  public void addRoot(WriterEncryptionVariant root) {
+    roots.add(root);
+  }
+
+  public HadoopShims.KeyMetadata getMetadata() {
+    return metadata;
+  }
+
+  public void setId(int id) {
+    this.id = id;
+  }
+
+  @Override
+  public String getKeyName() {
+    return metadata.getKeyName();
+  }
+
+  @Override
+  public int getKeyVersion() {
+    return metadata.getVersion();
+  }
+
+  public EncryptionAlgorithm getAlgorithm() {
+    return metadata.getAlgorithm();
+  }
+
+  @Override
+  public WriterEncryptionVariant[] getEncryptionRoots() {
+    return roots.toArray(new WriterEncryptionVariant[roots.size()]);
+  }
+
+  public int getId() {
+    return id;
+  }
+
+  public void sortRoots() {
+    Collections.sort(roots);
+  }
+
+  @Override
+  public int hashCode() {
+    return id;
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other == null || getClass() != other.getClass()) {
+      return false;
+    }
+    return compareTo((EncryptionKey) other) == 0;
+  }
+
+  @Override
+  public int compareTo(@NotNull EncryptionKey other) {
+    int result = getKeyName().compareTo(other.getKeyName());
+    if (result == 0) {
+      result = Integer.compare(getKeyVersion(), other.getKeyVersion());
+    }
+    return result;
+  }
+
+  @Override
+  public String toString() {
+    return metadata.toString();
+  }
+}
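Note: each WriterEncryptionKey wraps one master key and collects the encrypted
column subtrees (variants) written under it. A hedged setup sketch, assuming a
HadoopShims.KeyMetadata instance "meta" and a WriterEncryptionVariant "variant"
already exist:

    // One key object per master key; every encrypted subtree root is
    // registered on it as a variant.
    WriterEncryptionKey key = new WriterEncryptionKey(meta);
    key.setId(0);
    key.addRoot(variant);
    // Variants order by key name/version, then root column id, so sorting
    // gives a deterministic layout for the file.
    key.sortRoots();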
diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java
new file mode 100644
index 0000000000..ed026f7a90
--- /dev/null
+++ b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.orc.impl.writer;
+
+import org.apache.orc.EncryptionVariant;
+import org.apache.orc.OrcProto;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.impl.LocalKey;
+import org.jetbrains.annotations.NotNull;
+
+import java.security.Key;
+import java.util.ArrayList;
+import java.util.List;
+
+public class WriterEncryptionVariant implements EncryptionVariant {
+  private int id;
+  private final WriterEncryptionKey key;
+  private final TypeDescription root;
+  private final LocalKey material;
+  private final OrcProto.FileStatistics.Builder fileStats =
+      OrcProto.FileStatistics.newBuilder();
+  private final List<OrcProto.ColumnEncoding> encodings = new ArrayList<>();
+
+  public WriterEncryptionVariant(WriterEncryptionKey key,
+                                 TypeDescription root,
+                                 LocalKey columnKey) {
+    this.key = key;
+    this.root = root;
+    this.material = columnKey;
+  }
+
+  @Override
+  public WriterEncryptionKey getKeyDescription() {
+    return key;
+  }
+
+  public TypeDescription getRoot() {
+    return root;
+  }
+
+  public void setId(int id) {
+    this.id = id;
+  }
+
+  @Override
+  public int getVariantId() {
+    return id;
+  }
+
+  @Override
+  public Key getFileFooterKey() {
+    return material.getDecryptedKey();
+  }
+
+  @Override
+  public Key getStripeKey(long stripe) {
+    return material.getDecryptedKey();
+  }
+
+  public LocalKey getMaterial() {
+    return material;
+  }
+
+  public void clearFileStatistics() {
+    fileStats.clearColumn();
+  }
+
+  public void addFileStatistics(OrcProto.ColumnStatistics column) {
+    fileStats.addColumn(column);
+  }
+
+  public OrcProto.FileStatistics getFileStatistics() {
+    return fileStats.build();
+  }
+
+  public void addEncoding(OrcProto.ColumnEncoding encoding) {
+    encodings.add(encoding);
+  }
+
+  public List<OrcProto.ColumnEncoding> getEncodings() {
+    return encodings;
+  }
+
+  public void clearEncodings() {
+    encodings.clear();
+  }
+
+  @Override
+  public int hashCode() {
+    return key.hashCode() << 16 ^ root.getId();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other == this) {
+      return true;
+    } else if (other == null || other.getClass() != getClass()) {
+      return false;
+    }
+    return compareTo((WriterEncryptionVariant) other) == 0;
+  }
+
+  @Override
+  public int compareTo(@NotNull EncryptionVariant other) {
+    int result = key.compareTo(other.getKeyDescription());
+    if (result == 0) {
+      result = Integer.compare(root.getId(), other.getRoot().getId());
+    }
+    return result;
+  }
+}
+
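Note: a variant is the unit of encryption bookkeeping. It owns the local key
material, the per-stripe column encodings, and the encrypted file statistics
for one subtree. A hedged sketch of the per-stripe flow ("variant" and
"stripeId" assumed in scope):

    // Record this stripe's encoding and fetch the key that encrypts it.
    // With LocalKey material, the same key covers every stripe of the file.
    variant.addEncoding(OrcProto.ColumnEncoding.newBuilder()
        .setKind(OrcProto.ColumnEncoding.Kind.DIRECT).build());
    java.security.Key stripeKey = variant.getStripeKey(stripeId);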
diff --git a/java/core/src/test/org/apache/orc/TestStringDictionary.java b/java/core/src/test/org/apache/orc/TestStringDictionary.java
index 27965fe375..b0d39a0931 100644
--- a/java/core/src/test/org/apache/orc/TestStringDictionary.java
+++ b/java/core/src/test/org/apache/orc/TestStringDictionary.java
@@ -40,6 +40,7 @@
 import org.apache.orc.impl.writer.StringTreeWriter;
 import org.apache.orc.impl.writer.TreeWriter;
 import org.apache.orc.impl.writer.WriterContext;
+import org.apache.orc.impl.writer.WriterEncryptionVariant;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
@@ -172,9 +173,9 @@ static class WriterContextImpl implements WriterContext {
     }

     @Override
-    public OutStream createStream(int column, OrcProto.Stream.Kind kind) throws IOException {
+    public OutStream createStream(StreamName name) throws IOException {
       TestInStream.OutputCollector collect = new TestInStream.OutputCollector();
-      streams.put(new StreamName(column, kind), collect);
+      streams.put(name, collect);
       return new OutStream("test", new StreamOptions(1000), collect);
     }

@@ -223,6 +224,16 @@ public PhysicalWriter getPhysicalWriter() {
       return null;
     }

+    @Override
+    public void setEncoding(int column, WriterEncryptionVariant variant, OrcProto.ColumnEncoding encoding) {
+
+    }
+
+    @Override
+    public void writeStatistics(StreamName name, OrcProto.ColumnStatistics.Builder stats) throws IOException {
+
+    }
+
     @Override
     public OrcFile.BloomFilterVersion getBloomFilterVersion() {
       return OrcFile.BloomFilterVersion.UTF8;
@@ -239,6 +250,16 @@ public void writeBloomFilter(StreamName name,

     }

+    @Override
+    public DataMask getUnencryptedMask(int columnId) {
+      return null;
+    }
+
+    @Override
+    public WriterEncryptionVariant getEncryption(int columnId) {
+      return null;
+    }
+
     @Override
     public boolean getUseUTCTimestamp() {
       return true;
@@ -257,7 +278,7 @@ public void testNonDistinctDisabled() throws Exception {
     conf.set(OrcConf.DICTIONARY_KEY_SIZE_THRESHOLD.getAttribute(), "0.0");
     WriterContextImpl writerContext = new WriterContextImpl(schema, conf);
     StringTreeWriter writer = (StringTreeWriter)
-        TreeWriter.Factory.create(schema, writerContext, true);
+        TreeWriter.Factory.create(schema, null, writerContext);

     VectorizedRowBatch batch = schema.createRowBatch();
     BytesColumnVector col = (BytesColumnVector) batch.cols[0];
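Note: test doubles can leave the new hooks inert, as above. Returning null from
getEncryption and getUnencryptedMask keeps every column on the unencrypted
path, so the existing dictionary assertions run unchanged. A short usage
sketch against such a stub (the batch contents are illustrative):

    // With a null variant the writer behaves exactly as before encryption.
    VectorizedRowBatch batch = schema.createRowBatch();
    BytesColumnVector col = (BytesColumnVector) batch.cols[0];
    col.setVal(0, "hello".getBytes(StandardCharsets.UTF_8));
    batch.size = 1;
    writer.writeBatch(batch.cols[0], 0, batch.size);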
diff --git a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java
index 658c1cea71..95f6458926 100644
--- a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java
+++ b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java
@@ -2099,7 +2099,7 @@ private CompressionCodec writeBatchesAndGetCodec(int count,
                                                  ) throws IOException {
     fs.delete(testFilePath, false);
     PhysicalWriter physical = new PhysicalFsWriter(fs, testFilePath, opts);
-    CompressionCodec codec = physical.getCompressionCodec();
+    CompressionCodec codec = physical.getStreamOptions().getCodec();
     Writer writer = OrcFile.createWriter(testFilePath,
                                          opts.physicalWriter(physical));
     writeRandomIntBytesBatches(writer, batch, count, size);
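Note: with getCompressionCodec() removed from PhysicalWriter, callers reach the
codec through the unencrypted stream options, which also carry the buffer size
and any compression settings. A minimal sketch of the replacement call
(variable names are illustrative):

    // StreamOptions now bundles what the removed accessor used to expose.
    PhysicalWriter physical = new PhysicalFsWriter(fs, path, opts);
    CompressionCodec codec = physical.getStreamOptions().getCodec();
    int bufferSize = physical.getStreamOptions().getBufferSize();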