Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions native/core/src/execution/shuffle/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ impl<W: Write> BatchWriter<W> {
// be determined from the data buffer size (length is in bits rather than bytes)
self.write_all(&arr.len().to_le_bytes())?;
// write data buffer
self.write_buffer(arr.values().inner())?;
self.write_boolean_buffer(arr.values())?;
// write null buffer
self.write_null_buffer(arr.nulls())?;
}
Expand Down Expand Up @@ -300,8 +300,7 @@ impl<W: Write> BatchWriter<W> {
// write null buffer length in bits
self.write_all(&buffer.len().to_le_bytes())?;
// write null buffer
let buffer = buffer.inner();
self.write_buffer(buffer)?;
self.write_boolean_buffer(buffer)?;
} else {
self.inner.write_all(&0_usize.to_le_bytes())?;
}
Expand All @@ -315,6 +314,19 @@ impl<W: Write> BatchWriter<W> {
self.inner.write_all(buffer.as_slice())
}

fn write_boolean_buffer(&mut self, buffer: &BooleanBuffer) -> std::io::Result<()> {
let inner_buffer = buffer.inner();
if buffer.offset() == 0 && buffer.len() == inner_buffer.len() {
// Not a sliced buffer, write the inner buffer directly
self.write_buffer(inner_buffer)?;
} else {
// Sliced buffer, create and write the sliced buffer
let buffer = buffer.sliced();
self.write_buffer(&buffer)?;
}
Ok(())
}

pub fn inner(self) -> W {
self.inner
}
Expand Down Expand Up @@ -621,6 +633,29 @@ mod test {
assert_eq!(batch, batch2);
}

#[test]
fn roundtrip_sliced() {
let batch = create_batch(8192, true);

let mut start = 0;
let batch_size = 128;
while start < batch.num_rows() {
let end = (start + batch_size).min(batch.num_rows());
let sliced_batch = batch.slice(start, end - start);
let buffer = Vec::new();
let mut writer = BatchWriter::new(buffer);
writer.write_partial_schema(&sliced_batch.schema()).unwrap();
writer.write_batch(&sliced_batch).unwrap();
let buffer = writer.inner();

let mut reader = BatchReader::new(&buffer);
let batch2 = reader.read_batch().unwrap();
assert_eq!(sliced_batch, batch2);

start = end;
}
}

fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("bool", DataType::Boolean, true),
Expand Down
Loading