diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs
index 8d7b454310..b45f155aae 100644
--- a/native/core/src/execution/shuffle/codec.rs
+++ b/native/core/src/execution/shuffle/codec.rs
@@ -181,7 +181,7 @@ impl BatchWriter {
                 // be determined from the data buffer size (length is in bits rather than bytes)
                 self.write_all(&arr.len().to_le_bytes())?;
                 // write data buffer
-                self.write_buffer(arr.values().inner())?;
+                self.write_boolean_buffer(arr.values())?;
                 // write null buffer
                 self.write_null_buffer(arr.nulls())?;
             }
@@ -300,8 +300,7 @@ impl BatchWriter {
             // write null buffer length in bits
             self.write_all(&buffer.len().to_le_bytes())?;
             // write null buffer
-            let buffer = buffer.inner();
-            self.write_buffer(buffer)?;
+            self.write_boolean_buffer(buffer)?;
         } else {
             self.inner.write_all(&0_usize.to_le_bytes())?;
         }
@@ -315,6 +314,25 @@ impl BatchWriter {
         self.inner.write_all(buffer.as_slice())
     }
 
+    /// Writes the packed bits of `buffer`, taking any slicing into account.
+    ///
+    /// `BooleanBuffer::len` is in bits whereas the length of the inner byte buffer is in
+    /// bytes, so the "not sliced" check below compares byte counts.
+    fn write_boolean_buffer(&mut self, buffer: &BooleanBuffer) -> std::io::Result<()> {
+        let inner_buffer = buffer.inner();
+        // number of bytes required to hold `buffer.len()` bits
+        let num_bytes = buffer.len().div_ceil(8);
+        if buffer.offset() == 0 && inner_buffer.len() == num_bytes {
+            // Not a sliced buffer, write the inner buffer directly
+            self.write_buffer(inner_buffer)?;
+        } else {
+            // Sliced buffer, create and write the sliced buffer
+            let buffer = buffer.sliced();
+            self.write_buffer(&buffer)?;
+        }
+        Ok(())
+    }
+
     pub fn inner(self) -> W {
         self.inner
     }
@@ -621,6 +639,29 @@ mod test {
         assert_eq!(batch, batch2);
     }
 
+    #[test]
+    fn roundtrip_sliced() {
+        let batch = create_batch(8192, true);
+
+        let mut start = 0;
+        let batch_size = 128;
+        while start < batch.num_rows() {
+            let end = (start + batch_size).min(batch.num_rows());
+            let sliced_batch = batch.slice(start, end - start);
+            let buffer = Vec::new();
+            let mut writer = BatchWriter::new(buffer);
+            writer.write_partial_schema(&sliced_batch.schema()).unwrap();
+            writer.write_batch(&sliced_batch).unwrap();
+            let buffer = writer.inner();
+
+            let mut reader = BatchReader::new(&buffer);
+            let batch2 = reader.read_batch().unwrap();
+            assert_eq!(sliced_batch, batch2);
+
+            start = end;
+        }
+    }
+
     fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch {
         let schema = Arc::new(Schema::new(vec![
             Field::new("bool", DataType::Boolean, true),