Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions cpp/src/arrow/ipc/adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,12 @@ class RecordBatchWriter : public ArrayVisitor {
num_rows_, body_length, field_nodes_, buffer_meta_, &metadata_fb));

// Need to write 4 bytes (metadata size), the metadata, plus padding to
// fall on a 64-byte offset
int64_t padded_metadata_length =
BitUtil::RoundUpToMultipleOf64(metadata_fb->size() + 4);
// fall on an 8-byte offset
int64_t padded_metadata_length = BitUtil::CeilByte(metadata_fb->size() + 4);

// The returned metadata size includes the length prefix, the flatbuffer,
// plus padding
*metadata_length = padded_metadata_length;
*metadata_length = static_cast<int32_t>(padded_metadata_length);

// Write the flatbuffer size prefix
int32_t flatbuffer_size = metadata_fb->size();
Expand Down Expand Up @@ -604,7 +603,9 @@ Status ReadRecordBatchMetadata(int64_t offset, int32_t metadata_length,
return Status::Invalid(ss.str());
}

*metadata = std::make_shared<RecordBatchMetadata>(buffer, sizeof(int32_t));
std::shared_ptr<Message> message;
RETURN_NOT_OK(Message::Open(buffer, 4, &message));
*metadata = std::make_shared<RecordBatchMetadata>(message);
return Status::OK();
}

Expand Down
21 changes: 4 additions & 17 deletions cpp/src/arrow/ipc/metadata-internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,23 +320,10 @@ Status MessageBuilder::SetRecordBatch(int32_t length, int64_t body_length,
Status WriteRecordBatchMetadata(int32_t length, int64_t body_length,
const std::vector<flatbuf::FieldNode>& nodes,
const std::vector<flatbuf::Buffer>& buffers, std::shared_ptr<Buffer>* out) {
flatbuffers::FlatBufferBuilder fbb;

auto batch = flatbuf::CreateRecordBatch(
fbb, length, fbb.CreateVectorOfStructs(nodes), fbb.CreateVectorOfStructs(buffers));

fbb.Finish(batch);

int32_t size = fbb.GetSize();

auto result = std::make_shared<PoolBuffer>();
RETURN_NOT_OK(result->Resize(size));

uint8_t* dst = result->mutable_data();
memcpy(dst, fbb.GetBufferPointer(), size);

*out = result;
return Status::OK();
MessageBuilder builder;
RETURN_NOT_OK(builder.SetRecordBatch(length, body_length, nodes, buffers));
RETURN_NOT_OK(builder.Finish());
return builder.GetBuffer(out);
}

Status MessageBuilder::Finish() {
Expand Down
2 changes: 1 addition & 1 deletion format/File.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct Block {

/// Length of the data (this is aligned so there can be a gap between this and
/// the metadata).
bodyLength: int;
bodyLength: long;
}

root_type Footer;
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ public class ArrowBlock implements FBSerializable {

private final long offset;
private final int metadataLength;
private final int bodyLength;
private final long bodyLength;

public ArrowBlock(long offset, int metadataLength, int bodyLength) {
public ArrowBlock(long offset, int metadataLength, long bodyLength) {
super();
this.offset = offset;
this.metadataLength = metadataLength;
Expand All @@ -43,7 +43,7 @@ public int getMetadataLength() {
return metadataLength;
}

public int getBodyLength() {
public long getBodyLength() {
return bodyLength;
}

Expand All @@ -56,7 +56,7 @@ public int writeTo(FlatBufferBuilder builder) {
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + bodyLength;
result = prime * result + (int) (bodyLength ^ (bodyLength >>> 32));
result = prime * result + metadataLength;
result = prime * result + (int) (offset ^ (offset >>> 32));
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,33 +70,24 @@ public static int bytesToInt(byte[] bytes) {
*/
public static long serialize(WriteChannel out, Schema schema) throws IOException {
FlatBufferBuilder builder = new FlatBufferBuilder();
builder.finish(schema.getSchema(builder));
ByteBuffer serializedBody = builder.dataBuffer();
ByteBuffer serializedHeader =
serializeHeader(MessageHeader.Schema, serializedBody.remaining());

long size = out.writeIntLittleEndian(serializedHeader.remaining());
size += out.write(serializedHeader);
size += out.write(serializedBody);
int schemaOffset = schema.getSchema(builder);
ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.Schema, schemaOffset, 0);
long size = out.writeIntLittleEndian(serializedMessage.remaining());
size += out.write(serializedMessage);
return size;
}

/**
* Deserializes a schema object. Format is from serialize().
*/
public static Schema deserializeSchema(ReadChannel in) throws IOException {
Message header = deserializeHeader(in, MessageHeader.Schema);
if (header == null) {
Message message = deserializeMessage(in, MessageHeader.Schema);
if (message == null) {
throw new IOException("Unexpected end of input. Missing schema.");
}

// Now read the schema.
ByteBuffer buffer = ByteBuffer.allocate((int)header.bodyLength());
if (in.readFully(buffer) != header.bodyLength()) {
throw new IOException("Unexpected end of input trying to read schema.");
}
buffer.rewind();
return Schema.deserialize(buffer);
return Schema.convertSchema((org.apache.arrow.flatbuf.Schema)
message.header(new org.apache.arrow.flatbuf.Schema()));
}

/**
Expand All @@ -106,37 +97,22 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
throws IOException {
long start = out.getCurrentPosition();
int bodyLength = batch.computeBodyLength();
ByteBuffer metadata = WriteChannel.serialize(batch);

int messageLength = 4 + metadata.remaining() + bodyLength;
ByteBuffer serializedHeader =
serializeHeader(MessageHeader.RecordBatch, messageLength);

// Compute the required alignment. This is not a great way to do it. The issue is
// that we need to know the message size to serialize the message header but the
// size depends on the alignment, which depends on the message header.
// This will serialize the header again with the updated size alignment adjusted.
// TODO: We really just want sizeof(MessageHeader) from the serializeHeader() above.
// Is there a way to do this?
long bufferOffset = start + 4 + serializedHeader.remaining() + 4 + metadata.remaining();
if (bufferOffset % 8 != 0) {
messageLength += 8 - bufferOffset % 8;
serializedHeader = serializeHeader(MessageHeader.RecordBatch, messageLength);
}

// Write message header.
out.writeIntLittleEndian(serializedHeader.remaining());
out.write(serializedHeader);
FlatBufferBuilder builder = new FlatBufferBuilder();
int batchOffset = batch.writeTo(builder);

ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch,
batchOffset, bodyLength);

// Write the batch header with its 4-byte little-endian length prefix
out.writeIntLittleEndian(metadata.remaining());
int metadataSize = metadata.remaining();
long batchStart = out.getCurrentPosition();
out.write(metadata);
long metadataStart = out.getCurrentPosition();
out.writeIntLittleEndian(serializedMessage.remaining());
out.write(serializedMessage);

// Align the output to 8 byte boundary.
out.align();

long metadataSize = out.getCurrentPosition() - metadataStart;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's significantly simpler if you include the padding in the metadata size. The FlatBuffers library shouldn't have a problem with extra bytes at the end of the buffer (it doesn't in C++).


long bufferStart = out.getCurrentPosition();
List<ArrowBuf> buffers = batch.getBuffers();
List<ArrowBuffer> buffersLayout = batch.getBuffersLayout();
Expand All @@ -154,31 +130,31 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
" != " + startPosition + layout.getSize());
}
}
return new ArrowBlock(batchStart, metadataSize, (int)(out.getCurrentPosition() - bufferStart));
return new ArrowBlock(start, (int) metadataSize, out.getCurrentPosition() - bufferStart);
}

/**
* Deserializes a RecordBatch
*/
public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
BufferAllocator alloc) throws IOException {
Message header = deserializeHeader(in, MessageHeader.RecordBatch);
if (header == null) return null;
Message message = deserializeMessage(in, MessageHeader.RecordBatch);
if (message == null) return null;

int messageLen = (int)header.bodyLength();
// Now read the buffer. This has the metadata followed by the data.
ArrowBuf buffer = alloc.buffer(messageLen);
long readPosition = in.getCurrentPositiion();
if (in.readFully(buffer, messageLen) != messageLen) {
throw new IOException("Unexpected end of input trying to read batch.");
if (message.bodyLength() > Integer.MAX_VALUE) {
throw new IOException("Cannot currently deserialize record batches over 2GB");
}

// Read the length of the metadata.
int metadataLen = buffer.readInt();
buffer = buffer.slice(4, messageLen - 4);
readPosition += 4;
messageLen -= 4;
return deserializeRecordBatch(buffer, readPosition, metadataLen, messageLen);
RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch());

int bodyLength = (int) message.bodyLength();

// Now read the record batch body
ArrowBuf buffer = alloc.buffer(bodyLength);
if (in.readFully(buffer, bodyLength) != bodyLength) {
throw new IOException("Unexpected end of input trying to read batch.");
}
return deserializeRecordBatch(recordBatchFB, buffer);
}

/**
Expand All @@ -188,36 +164,41 @@ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock block,
BufferAllocator alloc) throws IOException {
long readPosition = in.getCurrentPositiion();
int totalLen = block.getMetadataLength() + block.getBodyLength();
if ((readPosition + block.getMetadataLength()) % 8 != 0) {
// Compute padded size.
totalLen += (8 - (readPosition + block.getMetadataLength()) % 8);

// Metadata length contains byte padding
long totalLen = block.getMetadataLength() + block.getBodyLength();

if (totalLen > Integer.MAX_VALUE) {
throw new IOException("Cannot currently deserialize record batches over 2GB");
}

ArrowBuf buffer = alloc.buffer(totalLen);
if (in.readFully(buffer, totalLen) != totalLen) {
ArrowBuf buffer = alloc.buffer((int) totalLen);
if (in.readFully(buffer, (int) totalLen) != totalLen) {
throw new IOException("Unexpected end of input trying to read batch.");
}

return deserializeRecordBatch(buffer, readPosition, block.getMetadataLength(), totalLen);
return deserializeRecordBatch(buffer, block.getMetadataLength(), (int) totalLen);
}

// Deserializes a record batch. Buffer should start at the RecordBatch and include
// all the bytes for the metadata and then data buffers.
private static ArrowRecordBatch deserializeRecordBatch(
ArrowBuf buffer, long readPosition, int metadataLen, int bufferLen) {
private static ArrowRecordBatch deserializeRecordBatch(ArrowBuf buffer, int metadataLen,
int bufferLen) {
// Read the metadata.
RecordBatch recordBatchFB =
RecordBatch.getRootAsRecordBatch(buffer.nioBuffer().asReadOnlyBuffer());

int bufferOffset = metadataLen;
readPosition += bufferOffset;
if (readPosition % 8 != 0) {
bufferOffset += (int)(8 - readPosition % 8);
}

// Now read the body
final ArrowBuf body = buffer.slice(bufferOffset, bufferLen - bufferOffset);
return deserializeRecordBatch(recordBatchFB, body);
}

// Deserializes a record batch given the Flatbuffer metadata and in-memory body
private static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB,
ArrowBuf body) {
// Now read the body
int nodesLength = recordBatchFB.nodesLength();
List<ArrowFieldNode> nodes = new ArrayList<>();
for (int i = 0; i < nodesLength; ++i) {
Expand All @@ -232,43 +213,44 @@ private static ArrowRecordBatch deserializeRecordBatch(
}
ArrowRecordBatch arrowRecordBatch =
new ArrowRecordBatch(recordBatchFB.length(), nodes, buffers);
buffer.release();
body.release();
return arrowRecordBatch;
}

/**
* Serializes a message header.
*/
private static ByteBuffer serializeHeader(byte headerType, int bodyLength) {
FlatBufferBuilder headerBuilder = new FlatBufferBuilder();
Message.startMessage(headerBuilder);
Message.addHeaderType(headerBuilder, headerType);
Message.addVersion(headerBuilder, MetadataVersion.V1);
Message.addBodyLength(headerBuilder, bodyLength);
headerBuilder.finish(Message.endMessage(headerBuilder));
return headerBuilder.dataBuffer();
private static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte headerType,
int headerOffset, int bodyLength) {
Message.startMessage(builder);
Message.addHeaderType(builder, headerType);
Message.addHeader(builder, headerOffset);
Message.addVersion(builder, MetadataVersion.V1);
Message.addBodyLength(builder, bodyLength);
builder.finish(Message.endMessage(builder));
return builder.dataBuffer();
}

private static Message deserializeHeader(ReadChannel in, byte headerType) throws IOException {
// Read the header size. There is an i32 little endian prefix.
private static Message deserializeMessage(ReadChannel in, byte headerType) throws IOException {
// Read the message size. There is an i32 little endian prefix.
ByteBuffer buffer = ByteBuffer.allocate(4);
if (in.readFully(buffer) != 4) {
return null;
}

int headerLength = bytesToInt(buffer.array());
buffer = ByteBuffer.allocate(headerLength);
if (in.readFully(buffer) != headerLength) {
int messageLength = bytesToInt(buffer.array());
buffer = ByteBuffer.allocate(messageLength);
if (in.readFully(buffer) != messageLength) {
throw new IOException(
"Unexpected end of stream trying to read header.");
"Unexpected end of stream trying to read message.");
}
buffer.rewind();

Message header = Message.getRootAsMessage(buffer);
if (header.headerType() != headerType) {
Message message = Message.getRootAsMessage(buffer);
if (message.headerType() != headerType) {
throw new IOException("Invalid message: expecting " + headerType +
". Message contained: " + header.headerType());
". Message contained: " + message.headerType());
}
return header;
return message;
}
}