diff --git a/cpp/src/arrow/ipc/adapter.cc b/cpp/src/arrow/ipc/adapter.cc index 2b5ef11f861..7b4d18c267d 100644 --- a/cpp/src/arrow/ipc/adapter.cc +++ b/cpp/src/arrow/ipc/adapter.cc @@ -129,13 +129,12 @@ class RecordBatchWriter : public ArrayVisitor { num_rows_, body_length, field_nodes_, buffer_meta_, &metadata_fb)); // Need to write 4 bytes (metadata size), the metadata, plus padding to - // fall on a 64-byte offset - int64_t padded_metadata_length = - BitUtil::RoundUpToMultipleOf64(metadata_fb->size() + 4); + // fall on an 8-byte offset + int64_t padded_metadata_length = BitUtil::CeilByte(metadata_fb->size() + 4); // The returned metadata size includes the length prefix, the flatbuffer, // plus padding - *metadata_length = padded_metadata_length; + *metadata_length = static_cast(padded_metadata_length); // Write the flatbuffer size prefix int32_t flatbuffer_size = metadata_fb->size(); @@ -604,7 +603,9 @@ Status ReadRecordBatchMetadata(int64_t offset, int32_t metadata_length, return Status::Invalid(ss.str()); } - *metadata = std::make_shared(buffer, sizeof(int32_t)); + std::shared_ptr message; + RETURN_NOT_OK(Message::Open(buffer, 4, &message)); + *metadata = std::make_shared(message); return Status::OK(); } diff --git a/cpp/src/arrow/ipc/metadata-internal.cc b/cpp/src/arrow/ipc/metadata-internal.cc index 16069a8f9dc..cc160c42ec9 100644 --- a/cpp/src/arrow/ipc/metadata-internal.cc +++ b/cpp/src/arrow/ipc/metadata-internal.cc @@ -320,23 +320,10 @@ Status MessageBuilder::SetRecordBatch(int32_t length, int64_t body_length, Status WriteRecordBatchMetadata(int32_t length, int64_t body_length, const std::vector& nodes, const std::vector& buffers, std::shared_ptr* out) { - flatbuffers::FlatBufferBuilder fbb; - - auto batch = flatbuf::CreateRecordBatch( - fbb, length, fbb.CreateVectorOfStructs(nodes), fbb.CreateVectorOfStructs(buffers)); - - fbb.Finish(batch); - - int32_t size = fbb.GetSize(); - - auto result = std::make_shared(); - RETURN_NOT_OK(result->Resize(size)); - 
- uint8_t* dst = result->mutable_data(); - memcpy(dst, fbb.GetBufferPointer(), size); - - *out = result; - return Status::OK(); + MessageBuilder builder; + RETURN_NOT_OK(builder.SetRecordBatch(length, body_length, nodes, buffers)); + RETURN_NOT_OK(builder.Finish()); + return builder.GetBuffer(out); } Status MessageBuilder::Finish() { diff --git a/format/File.fbs b/format/File.fbs index 86b4b22a92d..e8d6da4f848 100644 --- a/format/File.fbs +++ b/format/File.fbs @@ -43,7 +43,7 @@ struct Block { /// Length of the data (this is aligned so there can be a gap between this and /// the metatdata). - bodyLength: int; + bodyLength: long; } root_type Footer; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowBlock.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowBlock.java index a55c283f40b..90fb02b0597 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowBlock.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowBlock.java @@ -26,9 +26,9 @@ public class ArrowBlock implements FBSerializable { private final long offset; private final int metadataLength; - private final int bodyLength; + private final long bodyLength; - public ArrowBlock(long offset, int metadataLength, int bodyLength) { + public ArrowBlock(long offset, int metadataLength, long bodyLength) { super(); this.offset = offset; this.metadataLength = metadataLength; @@ -43,7 +43,7 @@ public int getMetadataLength() { return metadataLength; } - public int getBodyLength() { + public long getBodyLength() { return bodyLength; } @@ -56,7 +56,7 @@ public int writeTo(FlatBufferBuilder builder) { public int hashCode() { final int prime = 31; int result = 1; - result = prime * result + bodyLength; + result = prime * result + (int) (bodyLength ^ (bodyLength >>> 32)); result = prime * result + metadataLength; result = prime * result + (int) (offset ^ (offset >>> 32)); return result; diff --git 
a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java index 61f59ae39e7..02bfd6b0975 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java @@ -70,14 +70,10 @@ public static int bytesToInt(byte[] bytes) { */ public static long serialize(WriteChannel out, Schema schema) throws IOException { FlatBufferBuilder builder = new FlatBufferBuilder(); - builder.finish(schema.getSchema(builder)); - ByteBuffer serializedBody = builder.dataBuffer(); - ByteBuffer serializedHeader = - serializeHeader(MessageHeader.Schema, serializedBody.remaining()); - - long size = out.writeIntLittleEndian(serializedHeader.remaining()); - size += out.write(serializedHeader); - size += out.write(serializedBody); + int schemaOffset = schema.getSchema(builder); + ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.Schema, schemaOffset, 0); + long size = out.writeIntLittleEndian(serializedMessage.remaining()); + size += out.write(serializedMessage); return size; } @@ -85,18 +81,13 @@ public static long serialize(WriteChannel out, Schema schema) throws IOException * Deserializes a schema object. Format is from serialize(). */ public static Schema deserializeSchema(ReadChannel in) throws IOException { - Message header = deserializeHeader(in, MessageHeader.Schema); - if (header == null) { + Message message = deserializeMessage(in, MessageHeader.Schema); + if (message == null) { throw new IOException("Unexpected end of input. Missing schema."); } - // Now read the schema. 
- ByteBuffer buffer = ByteBuffer.allocate((int)header.bodyLength()); - if (in.readFully(buffer) != header.bodyLength()) { - throw new IOException("Unexpected end of input trying to read schema."); - } - buffer.rewind(); - return Schema.deserialize(buffer); + return Schema.convertSchema((org.apache.arrow.flatbuf.Schema) + message.header(new org.apache.arrow.flatbuf.Schema())); } /** @@ -106,33 +97,23 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) throws IOException { long start = out.getCurrentPosition(); int bodyLength = batch.computeBodyLength(); - ByteBuffer metadata = WriteChannel.serialize(batch); - - int messageLength = 4 + metadata.remaining() + bodyLength; - ByteBuffer serializedHeader = - serializeHeader(MessageHeader.RecordBatch, messageLength); - - // Compute the required alignment. This is not a great way to do it. The issue is - // that we need to know the message size to serialize the message header but the - // size depends on the alignment, which depends on the message header. - // This will serialize the header again with the updated size alignment adjusted. - // TODO: We really just want sizeof(MessageHeader) from the serializeHeader() above. - // Is there a way to do this? - long bufferOffset = start + 4 + serializedHeader.remaining() + 4 + metadata.remaining(); - if (bufferOffset % 8 != 0) { - messageLength += 8 - bufferOffset % 8; - serializedHeader = serializeHeader(MessageHeader.RecordBatch, messageLength); - } - // Write message header. - out.writeIntLittleEndian(serializedHeader.remaining()); - out.write(serializedHeader); + FlatBufferBuilder builder = new FlatBufferBuilder(); + int batchOffset = batch.writeTo(builder); + + ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch, + batchOffset, bodyLength); + + int metadataLength = serializedMessage.remaining(); - // Write batch header. 
with the 4 byte little endian prefix - out.writeIntLittleEndian(metadata.remaining()); - int metadataSize = metadata.remaining(); - long batchStart = out.getCurrentPosition(); - out.write(metadata); + // Add extra padding bytes so that length prefix + metadata is a multiple + // of 8 after alignment + if ((metadataLength + 4) % 8 != 0) { + metadataLength += 8 - (metadataLength + 4) % 8; + } + + out.writeIntLittleEndian(metadataLength); + out.write(serializedMessage); // Align the output to 8 byte boundary. out.align(); @@ -154,7 +135,8 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) " != " + startPosition + layout.getSize()); } } - return new ArrowBlock(batchStart, metadataSize, (int)(out.getCurrentPosition() - bufferStart)); + // Metadata size in the Block accounts for the size prefix + return new ArrowBlock(start, metadataLength + 4, out.getCurrentPosition() - bufferStart); } /** @@ -162,23 +144,23 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) */ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, BufferAllocator alloc) throws IOException { - Message header = deserializeHeader(in, MessageHeader.RecordBatch); - if (header == null) return null; - - int messageLen = (int)header.bodyLength(); - // Now read the buffer. This has the metadata followed by the data. - ArrowBuf buffer = alloc.buffer(messageLen); - long readPosition = in.getCurrentPositiion(); - if (in.readFully(buffer, messageLen) != messageLen) { - throw new IOException("Unexpected end of input trying to read batch."); + Message message = deserializeMessage(in, MessageHeader.RecordBatch); + if (message == null) return null; + + if (message.bodyLength() > Integer.MAX_VALUE) { + throw new IOException("Cannot currently deserialize record batches over 2GB"); } - // Read the length of the metadata. 
- int metadataLen = buffer.readInt(); - buffer = buffer.slice(4, messageLen - 4); - readPosition += 4; - messageLen -= 4; - return deserializeRecordBatch(buffer, readPosition, metadataLen, messageLen); + RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch()); + + int bodyLength = (int) message.bodyLength(); + + // Now read the record batch body + ArrowBuf buffer = alloc.buffer(bodyLength); + if (in.readFully(buffer, bodyLength) != bodyLength) { + throw new IOException("Unexpected end of input trying to read batch."); + } + return deserializeRecordBatch(recordBatchFB, buffer); } /** @@ -187,37 +169,39 @@ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, */ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock block, BufferAllocator alloc) throws IOException { - long readPosition = in.getCurrentPositiion(); - int totalLen = block.getMetadataLength() + block.getBodyLength(); - if ((readPosition + block.getMetadataLength()) % 8 != 0) { - // Compute padded size. - totalLen += (8 - (readPosition + block.getMetadataLength()) % 8); + // Metadata length contains integer prefix plus byte padding + long totalLen = block.getMetadataLength() + block.getBodyLength(); + + if (totalLen > Integer.MAX_VALUE) { + throw new IOException("Cannot currently deserialize record batches over 2GB"); } - ArrowBuf buffer = alloc.buffer(totalLen); - if (in.readFully(buffer, totalLen) != totalLen) { + ArrowBuf buffer = alloc.buffer((int) totalLen); + if (in.readFully(buffer, (int) totalLen) != totalLen) { throw new IOException("Unexpected end of input trying to read batch."); } - return deserializeRecordBatch(buffer, readPosition, block.getMetadataLength(), totalLen); - } + ArrowBuf metadataBuffer = buffer.slice(4, block.getMetadataLength() - 4); - // Deserializes a record batch. Buffer should start at the RecordBatch and include - // all the bytes for the metadata and then data buffers. 
- private static ArrowRecordBatch deserializeRecordBatch( - ArrowBuf buffer, long readPosition, int metadataLen, int bufferLen) { // Read the metadata. RecordBatch recordBatchFB = - RecordBatch.getRootAsRecordBatch(buffer.nioBuffer().asReadOnlyBuffer()); + RecordBatch.getRootAsRecordBatch(metadataBuffer.nioBuffer().asReadOnlyBuffer()); - int bufferOffset = metadataLen; - readPosition += bufferOffset; - if (readPosition % 8 != 0) { - bufferOffset += (int)(8 - readPosition % 8); - } + // Now read the body + final ArrowBuf body = buffer.slice(block.getMetadataLength(), + (int) totalLen - block.getMetadataLength()); + ArrowRecordBatch result = deserializeRecordBatch(recordBatchFB, body); + metadataBuffer.release(); + buffer.release(); + + return result; + } + + // Deserializes a record batch given the Flatbuffer metadata and in-memory body + private static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB, + ArrowBuf body) { // Now read the body - final ArrowBuf body = buffer.slice(bufferOffset, bufferLen - bufferOffset); int nodesLength = recordBatchFB.nodesLength(); List nodes = new ArrayList<>(); for (int i = 0; i < nodesLength; ++i) { @@ -232,43 +216,44 @@ private static ArrowRecordBatch deserializeRecordBatch( } ArrowRecordBatch arrowRecordBatch = new ArrowRecordBatch(recordBatchFB.length(), nodes, buffers); - buffer.release(); + body.release(); return arrowRecordBatch; } /** * Serializes a message header. 
*/ - private static ByteBuffer serializeHeader(byte headerType, int bodyLength) { - FlatBufferBuilder headerBuilder = new FlatBufferBuilder(); - Message.startMessage(headerBuilder); - Message.addHeaderType(headerBuilder, headerType); - Message.addVersion(headerBuilder, MetadataVersion.V1); - Message.addBodyLength(headerBuilder, bodyLength); - headerBuilder.finish(Message.endMessage(headerBuilder)); - return headerBuilder.dataBuffer(); + private static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte headerType, + int headerOffset, int bodyLength) { + Message.startMessage(builder); + Message.addHeaderType(builder, headerType); + Message.addHeader(builder, headerOffset); + Message.addVersion(builder, MetadataVersion.V1); + Message.addBodyLength(builder, bodyLength); + builder.finish(Message.endMessage(builder)); + return builder.dataBuffer(); } - private static Message deserializeHeader(ReadChannel in, byte headerType) throws IOException { - // Read the header size. There is an i32 little endian prefix. + private static Message deserializeMessage(ReadChannel in, byte headerType) throws IOException { + // Read the message size. There is an i32 little endian prefix. ByteBuffer buffer = ByteBuffer.allocate(4); if (in.readFully(buffer) != 4) { return null; } - int headerLength = bytesToInt(buffer.array()); - buffer = ByteBuffer.allocate(headerLength); - if (in.readFully(buffer) != headerLength) { + int messageLength = bytesToInt(buffer.array()); + buffer = ByteBuffer.allocate(messageLength); + if (in.readFully(buffer) != messageLength) { throw new IOException( - "Unexpected end of stream trying to read header."); + "Unexpected end of stream trying to read message."); } buffer.rewind(); - Message header = Message.getRootAsMessage(buffer); - if (header.headerType() != headerType) { + Message message = Message.getRootAsMessage(buffer); + if (message.headerType() != headerType) { throw new IOException("Invalid message: expecting " + headerType + - ". 
Message contained: " + header.headerType()); + ". Message contained: " + message.headerType()); } - return header; + return message; } }