-
Notifications
You must be signed in to change notification settings - Fork 737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Parquet BYTE_STREAM_SPLIT
for INT32, INT64, and FIXED_LEN_BYTE_ARRAY primitive types
#6159
Conversation
Note to reviewers: I could use some help for the following: https://github.com/etseidl/arrow-rs/blob/1c1af32c097df124d00e0ecf84ae72fdc629e250/parquet/src/encodings/decoding/byte_stream_split_decoder.rs#L151-L158 IIUC, |
I have verified that https://github.com/apache/parquet-testing/blob/master/data/byte_stream_split_extended.gzip.parquet can be read properly by both parquet-read and parquet-rewrite (and a modified parquet-rewrite can round trip properly). |
🤔 it would be great to add some sort of test that shows this -- I was hoping we already had tests that read parquet files and verified the results, but sadly it appears we do not. I suppose this is one of the things that the parquet compatibility tests I proposed on apache/parquet-format#441 would handle |
@@ -1641,6 +1643,86 @@ mod tests { | |||
assert_eq!(row_count, 300); | |||
} | |||
|
|||
#[test] | |||
fn test_read_extended_byte_stream_split() { | |||
let path = format!( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 I see here we did have the test 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that tests one path, but this bypasses the BSS decoder in encodings::decoding::byte_stream_split_decoder
. parquet-read exercises that path, so I hope to recreate that path (goes through serialized file reader) in an additional test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see this test implements the suggestion from https://github.com/apache/parquet-testing/blob/master/data/README.md#additional-types
To check conformance of a BYTE_STREAM_SPLIT decoder, read each BYTE_STREAM_SPLIT-encoded column and compare the decoded values against the values from the corresponding PLAIN-encoded column. The values should be equal.
However, when I double checked the vaues with what pyarrow
python says they didn't seem to match 🤔
I printed out the f16 column:
f16_col: PrimitiveArray<Float16>
[
10.3046875,
8.9609375,
10.75,
10.9375,
8.046875,
8.6953125,
10.125,
9.6875,
9.984375,
9.1484375,
...108 elements...,
11.6015625,
9.7578125,
8.9765625,
10.1796875,
10.21875,
11.359375,
10.8359375,
10.359375,
11.4609375,
8.8125,
]
f32_col: PrimitiveArray<Float32>
[
8.827992,
9.48172,
11.511229,
10.637534,
9.301069,
8.986282,
10.032783,
8.78344,
9.328859,
10.31201,
...52 elements...,
7.6898966,
10.054354,
9.528224,
10.459386,
10.701954,
10.138242,
10.760133,
10.229212,
10.530065,
9.295327,
]
Here is what python told me:
Python 3.11.9 (main, Apr 2 2024, 08:25:04) [Clang 15.0.0 (clang-1500.3.9.4)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import pyarrow.parquet as pq
>>> table = pq.read_table('byte_stream_split_extended.gzip.parquet')
>>> table
pyarrow.Table
float16_plain: halffloat
float16_byte_stream_split: halffloat
float_plain: float
float_byte_stream_split: float
double_plain: double
double_byte_stream_split: double
int32_plain: int32
int32_byte_stream_split: int32
int64_plain: int64
int64_byte_stream_split: int64
flba5_plain: fixed_size_binary[5]
flba5_byte_stream_split: fixed_size_binary[5]
decimal_plain: decimal128(7, 3)
decimal_byte_stream_split: decimal128(7, 3)
----
float16_plain: [[18727,18555,18784,18808,18438,...,18573,18770,18637,18687,18667]]
float16_byte_stream_split: [[18727,18555,18784,18808,18438,...,18573,18770,18637,18687,18667]]
float_plain: [[10.337575,11.407482,10.090585,10.643939,7.9498277,...,10.138242,10.760133,10.229212,10.530065,9.295327]]
float_byte_stream_split: [[10.337575,11.407482,10.090585,10.643939,7.9498277,...,10.138242,10.760133,10.229212,10.530065,9.295327]]
double_plain: [[9.82038858616854,10.196776096656958,10.820528475417419,9.606258827775427,10.521167255732113,...,9.576393393539162,9.47941158714459,10.812601287753644,10.241659719820838,8.225037940357872]]
double_byte_stream_split: [[9.82038858616854,10.196776096656958,10.820528475417419,9.606258827775427,10.521167255732113,...,9.576393393539162,9.47941158714459,10.812601287753644,10.241659719820838,8.225037940357872]]
int32_plain: [[24191,41157,7403,79368,64983,...,3584,93802,95977,73925,10300]]
int32_byte_stream_split: [[24191,41157,7403,79368,64983,...,3584,93802,95977,73925,10300]]
int64_plain: [[293650000000,41079000000,51248000000,246066000000,572141000000,...,294755000000,343501000000,663621000000,976709000000,836245000000]]
int64_byte_stream_split: [[293650000000,41079000000,51248000000,246066000000,572141000000,...,294755000000,343501000000,663621000000,976709000000,836245000000]]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the pyarrow output looks like my parquet-read output (with the exception of the f16 columns). I'm not sure what happened with the f32_col
above, but I did find those values further down in the output. Weird batching?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fact that the existing, non BSS columns (not changes by this PR) come back the same gives me confidence that the code is doing the right thing. I just found it straange that python seemed to give me a different result
BYTE_STREAM_SPLIT
for INT32, INT64, and FIXED_LEN_BYTE_ARRAY primitive types
Yes! ❤️ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @etseidl
I got some wonky results of reading, which I don't really understand (maybe I did something wrong)
I am also not sure about the changes to impl<T: DataType> Decoder<T> for ByteStreamSplitDecoder<T> {
but I left some suggestions / comments.
@@ -1641,6 +1643,86 @@ mod tests { | |||
assert_eq!(row_count, 300); | |||
} | |||
|
|||
#[test] | |||
fn test_read_extended_byte_stream_split() { | |||
let path = format!( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see this test implements the suggestion from https://github.com/apache/parquet-testing/blob/master/data/README.md#additional-types
To check conformance of a BYTE_STREAM_SPLIT decoder, read each BYTE_STREAM_SPLIT-encoded column and compare the decoded values against the values from the corresponding PLAIN-encoded column. The values should be equal.
However, when I double checked the vaues with what pyarrow
python says they didn't seem to match 🤔
I printed out the f16 column:
f16_col: PrimitiveArray<Float16>
[
10.3046875,
8.9609375,
10.75,
10.9375,
8.046875,
8.6953125,
10.125,
9.6875,
9.984375,
9.1484375,
...108 elements...,
11.6015625,
9.7578125,
8.9765625,
10.1796875,
10.21875,
11.359375,
10.8359375,
10.359375,
11.4609375,
8.8125,
]
f32_col: PrimitiveArray<Float32>
[
8.827992,
9.48172,
11.511229,
10.637534,
9.301069,
8.986282,
10.032783,
8.78344,
9.328859,
10.31201,
...52 elements...,
7.6898966,
10.054354,
9.528224,
10.459386,
10.701954,
10.138242,
10.760133,
10.229212,
10.530065,
9.295327,
]
Here is what python told me:
Python 3.11.9 (main, Apr 2 2024, 08:25:04) [Clang 15.0.0 (clang-1500.3.9.4)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import pyarrow.parquet as pq
>>> table = pq.read_table('byte_stream_split_extended.gzip.parquet')
>>> table
pyarrow.Table
float16_plain: halffloat
float16_byte_stream_split: halffloat
float_plain: float
float_byte_stream_split: float
double_plain: double
double_byte_stream_split: double
int32_plain: int32
int32_byte_stream_split: int32
int64_plain: int64
int64_byte_stream_split: int64
flba5_plain: fixed_size_binary[5]
flba5_byte_stream_split: fixed_size_binary[5]
decimal_plain: decimal128(7, 3)
decimal_byte_stream_split: decimal128(7, 3)
----
float16_plain: [[18727,18555,18784,18808,18438,...,18573,18770,18637,18687,18667]]
float16_byte_stream_split: [[18727,18555,18784,18808,18438,...,18573,18770,18637,18687,18667]]
float_plain: [[10.337575,11.407482,10.090585,10.643939,7.9498277,...,10.138242,10.760133,10.229212,10.530065,9.295327]]
float_byte_stream_split: [[10.337575,11.407482,10.090585,10.643939,7.9498277,...,10.138242,10.760133,10.229212,10.530065,9.295327]]
double_plain: [[9.82038858616854,10.196776096656958,10.820528475417419,9.606258827775427,10.521167255732113,...,9.576393393539162,9.47941158714459,10.812601287753644,10.241659719820838,8.225037940357872]]
double_byte_stream_split: [[9.82038858616854,10.196776096656958,10.820528475417419,9.606258827775427,10.521167255732113,...,9.576393393539162,9.47941158714459,10.812601287753644,10.241659719820838,8.225037940357872]]
int32_plain: [[24191,41157,7403,79368,64983,...,3584,93802,95977,73925,10300]]
int32_byte_stream_split: [[24191,41157,7403,79368,64983,...,3584,93802,95977,73925,10300]]
int64_plain: [[293650000000,41079000000,51248000000,246066000000,572141000000,...,294755000000,343501000000,663621000000,976709000000,836245000000]]
int64_byte_stream_split: [[293650000000,41079000000,51248000000,246066000000,572141000000,...,294755000000,343501000000,663621000000,976709000000,836245000000]]
@@ -76,11 +94,32 @@ impl<T: DataType> Decoder<T> for ByteStreamSplitDecoder<T> { | |||
let num_values = buffer.len().min(total_remaining_values); | |||
let buffer = &mut buffer[..num_values]; | |||
|
|||
let type_size = match T::get_physical_type() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can probably figure out some way to encode this in the trait -- either Decoder
directly or maybe some function
What if we made ByteStreamSpitDecoder also be parameterized in the width in bytes:
impl<T: DataType, W: const usize> Decoder<T> for ByteStreamSplitDecoder<T, W> {
...
That woudl require knowing all the possible sizes (is that known aprior?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For FIXED_LEN_BYTE_ARRAY the size can be anything :( That's why I had to add non-parameterized versions of the split_streams
and join_streams
. I added some additional likely cases (2 for FLOAT16, 16 for UUID), but for FLBA(5) there's not much you can do.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤔 yeah -- maybe that is argument enough for creating a different decoder VariableSizedByteStreamSplitDecoder
🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, that's a good idea. I was trying to specialize a ByteStreamSplitDecoder
for FixedLenByteArrayType
but that didn't work so well 😅. An entirely new decoder would get rid of all the janky casting and such.
} | ||
self.values_decoded += num_values; | ||
|
||
// FIXME(ets): there's got to be a better way to do this |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I agree this is the thing we should try and figure out
I ran out of time to give this another look today but will try to get it tomorrow |
Thanks @alamb, but no hurry. I'm still thinking about additional tests. Once you have a look at the decoder changes, let me know if you want the same split done for the encoding side (i.e. add a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @etseidl -- I went over this again and I think it looks very nice and I think we could merge it as is
The only thing I want to do before doing so is run some benchmarks to make sure it doesn't have some unexpected performance ramifications. I have started these off and will report back.
I tried this out, but that wound up being a pretty big change because it requires passing a column descriptor when getting an Encoder.
I don't understand this comment -- it looks like you implemented VariableWidthByteStreamSplitEncoder
@@ -1641,6 +1643,86 @@ mod tests { | |||
assert_eq!(row_count, 300); | |||
} | |||
|
|||
#[test] | |||
fn test_read_extended_byte_stream_split() { | |||
let path = format!( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fact that the existing, non BSS columns (not changes by this PR) come back the same gives me confidence that the code is doing the right thing. I just found it straange that python seemed to give me a different result
@@ -27,6 +27,9 @@ use super::rle::RleDecoder; | |||
use crate::basic::*; | |||
use crate::data_type::private::ParquetValueType; | |||
use crate::data_type::*; | |||
use crate::encodings::decoding::byte_stream_split_decoder::{ | |||
ByteStreamSplitDecoder, VariableWidthByteStreamSplitDecoder, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is a good model -- to have a ByteStreamSplit deocder and a VariableWidthByteStreamSplitDecoder (I realize I also partly suggested it but I like how it looks)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks again! It was a good idea 😄
@@ -62,6 +63,22 @@ fn join_streams_const<const TYPE_SIZE: usize>( | |||
} | |||
} | |||
|
|||
// Like the above, but type_size is not known at compile time. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe it could be called join_streams_variable
to match the name of the decoder
fn set_data(&mut self, data: Bytes, num_values: usize) -> Result<()> { | ||
// Rough check that all data elements are the same length | ||
if data.len() % self.type_width != 0 { | ||
return Err(general_err!("Input data is not of fixed length")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we please make this slightly more informative -- something like data length {} is not a multiple of type width {}
|
||
let stride = self.encoded_bytes.len() / type_size; | ||
match type_size { | ||
2 => join_streams_const::<2>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since this is the variable length decoder, is there any reason to generate code for these special case lengths (2, 4, ...)? As in it could simply call join_streams
directly 🤔
I don't think it would be all that bad but it also may be unecessary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My assumption (unproven) is that the parameterized join_streams is faster. So the special cases are for known logical types that use FLBA as the physical type (although I should probably remove 4 and 8). If there is no advantage, then yes, the variable width decoder should just use the non-parameterized version (and perhaps the parameterized version could just go away).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is fine to keep
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I modified the benchmark to work with f16 as FixedLenByteArray(2). Good news is that using the templated _static
variants is significantly faster for both encode and decode.
The bad news is that FixedLenByteArray is very slow. This is not shocking due to the need for so many buffer copies. Get it working, then get it working fast, right? 😉
% cargo bench -p parquet --bench encoding --all-features -- --baseline bssopt
Compiling parquet v52.2.0 (/Users/seidl/src/arrow-rs/parquet)
Finished `bench` profile [optimized] target(s) in 45.92s
Running benches/encoding.rs (target/release/deps/encoding-534b69246994059e)
encoding: dtype=parquet::data_type::FixedLenByteArray, encoding=BYTE_STREAM_SPLIT
time: [142.70 µs 144.02 µs 145.71 µs]
change: [+26.586% +27.751% +29.264%] (p = 0.00 < 0.05)
Performance has regressed.
Found 7 outliers among 100 measurements (7.00%)
4 (4.00%) high mild
3 (3.00%) high severe
dtype=parquet::data_type::FixedLenByteArray, encoding=BYTE_STREAM_SPLIT encoded as 32768 bytes
decoding: dtype=parquet::data_type::FixedLenByteArray, encoding=BYTE_STREAM_SPLIT
time: [392.38 µs 393.06 µs 393.77 µs]
change: [+2.0067% +2.6708% +3.2941%] (p = 0.00 < 0.05)
Performance has regressed.
Found 6 outliers among 100 measurements (6.00%)
2 (2.00%) high mild
4 (4.00%) high severe
encoding: dtype=f32, encoding=BYTE_STREAM_SPLIT
time: [44.729 µs 46.314 µs 49.430 µs]
change: [-4.2202% -1.5434% +2.7807%] (p = 0.56 > 0.05)
No change in performance detected.
Found 4 outliers among 100 measurements (4.00%)
2 (2.00%) high mild
2 (2.00%) high severe
dtype=f32, encoding=BYTE_STREAM_SPLIT encoded as 65536 bytes
decoding: dtype=f32, encoding=BYTE_STREAM_SPLIT
time: [38.613 µs 38.697 µs 38.784 µs]
change: [-0.0019% +0.5500% +1.0931%] (p = 0.06 > 0.05)
No change in performance detected.
Found 8 outliers among 100 measurements (8.00%)
4 (4.00%) high mild
4 (4.00%) high severe
encoding: dtype=f64, encoding=BYTE_STREAM_SPLIT
time: [108.86 µs 109.25 µs 109.66 µs]
change: [-4.3488% -3.0332% -1.8343%] (p = 0.00 < 0.05)
Performance has improved.
Found 3 outliers among 100 measurements (3.00%)
2 (2.00%) high mild
1 (1.00%) high severe
dtype=f64, encoding=BYTE_STREAM_SPLIT encoded as 131072 bytes
decoding: dtype=f64, encoding=BYTE_STREAM_SPLIT
time: [81.127 µs 81.343 µs 81.566 µs]
change: [-3.6443% -2.9616% -2.2527%] (p = 0.00 < 0.05)
Performance has improved.
Found 8 outliers among 100 measurements (8.00%)
5 (5.00%) high mild
3 (3.00%) high severe
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bad news is that FixedLenByteArray is very slow. This is not shocking due to the need for so many buffer copies. Get it working, then get it working fast, right? 😉
Yes, I think that is right
Given what I saw of the code, the fact that decoding FixedLengthByteArray as individual Buffer
s (which each have an offset + length + arc) is going to be pretty slow.
In other words, I don't think the FixedLengthByteArray slowness is due anything speicific with BYTE_STREAM_SPLIT
. If we wanted to make it faster we would likely have to change how the parquet docoder represents the type
For example, we would likely not use this type: https://docs.rs/parquet/latest/parquet/data_type/struct.FixedLenByteArray.html
The ArrowReader has a bunch of specialized implementations for certain various array types to write directly into the arrow implementation (rather than the parquet types and then to the arrow types).
If anyone cares about reading fixed length binary more quickly from parquet it is probably good to take a look at how to optimize that more quickly (FYI @samuelcolvin and @westonpace who I think have been looking at FixedWidthBinary recently)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Plus the numbers here are for the row-based reader IIUC, so there shouldn't be much expectation of high performance anyway. And hopefully users won't be picking this encoding for FLBA very often. The motivating use case for this was Float16, but perhaps some small decimals encoded with FLBA would benefit as well.
One other in-the-weeds consideration is how cache unfriendly this encoding is. If you think of PLAIN data as a num_vals X type_width matrix, BSS is transposing that matrix. If type_width gets too large, there will be cache misses galore without some type of blocking during the transpose operation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
row-based reader IIUC
This is the key observation here, the row based reader is forced to perform an allocation for every value read and this is absolutely catastrophic from a performance standpoint. As @alamb points out the arrow readers have optimised codepaths for these types that avoid this issue, and should be preferred in use-cases that care about performance. We could probably document this more aggressively tbh...
|
||
// FIXME(ets): there's got to be a better way to do this | ||
for i in 0..num_values { | ||
if let Some(bi) = buffer[i].as_mut_any().downcast_mut::<FixedLenByteArray>() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this error/panic if the value isn't actually a FixexLenByteArary
?
Something like
let bi = buffer[i].as_mut_any().downcast_mut::<FixedLenByteArray>()
.expect("Decoding fixed length byte array");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hang on, I have a better idea...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this is what I came up with:
// create a buffer from the vec so far (and leave a new Vec in its place)
let vec_with_data = std::mem::take(&mut tmp_vec);
// convert Vec to Bytes (which is a ref counted wrapper)
let bytes_with_data = Bytes::from(vec_with_data);
for (i, bi) in buffer.iter_mut().enumerate().take(num_values) {
// Get a view into the data, without also copying the bytes
let data = bytes_with_data.slice(i * type_size..(i + 1) * type_size);
let bi = bi.as_mut_any()
.downcast_mut::<FixedLenByteArray>()
.expect("Decoding fixed length byte array");
bi.set_data(data);
}
I think it avoids a bunch of allocations (only does one allocation for each batch) but it is still pretty bad in terms of the downcast_mut
stuff 🤮. I suspect we would need to add some other trait method to DataType
(like set_from_bytes
or something to make it work
@@ -53,13 +52,24 @@ fn split_streams_const<const TYPE_SIZE: usize>(src: &[u8], dst: &mut [u8]) { | |||
} | |||
} | |||
|
|||
// Like above, but type_size is not known at compile time. | |||
fn split_streams(src: &[u8], dst: &mut [u8], type_size: usize) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto here -- maybe
fn split_streams(src: &[u8], dst: &mut [u8], type_size: usize) { | |
fn split_streams_variable(src: &[u8], dst: &mut [u8], type_size: usize) { |
Should have gone back and edited...after making some changes for the test code, I figured most of the changes needed were already present, so I went ahead and remove the byte width from the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ran the benchmarks and I see this which suggests that this branch is slower than master somehow for f32. I will rerun to see if I can reproduce the results
group bss master
----- --- ------
decoding: dtype=f32, encoding=BYTE_STREAM_SPLIT 1.00 30.5±0.03µs ? ?/sec 1.00 30.5±0.04µs ? ?/sec
decoding: dtype=f64, encoding=BYTE_STREAM_SPLIT 1.00 65.3±0.06µs ? ?/sec 1.00 65.3±0.10µs ? ?/sec
encoding: dtype=f32, encoding=BYTE_STREAM_SPLIT 1.22 41.6±0.03µs ? ?/sec 1.00 34.2±0.05µs ? ?/sec
encoding: dtype=f64, encoding=BYTE_STREAM_SPLIT 1.01 95.1±0.79µs ? ?/sec 1.00 93.8±0.72µs ? ?/sec
And the next run shows the same result somehow 🤔
group bss master
----- --- ------
decoding: dtype=f32, encoding=BYTE_STREAM_SPLIT 1.00 30.5±0.13µs ? ?/sec 1.00 30.5±0.04µs ? ?/sec
decoding: dtype=f64, encoding=BYTE_STREAM_SPLIT 1.00 65.2±0.05µs ? ?/sec 1.00 65.3±0.05µs ? ?/sec
encoding: dtype=f32, encoding=BYTE_STREAM_SPLIT 1.23 41.6±0.09µs ? ?/sec 1.00 33.7±0.25µs ? ?/sec
encoding: dtype=f64, encoding=BYTE_STREAM_SPLIT 1.01 94.7±0.81µs ? ?/sec 1.00 93.4±0.32µs ? ?/sec
I ran this command to benchmark, btw: cargo bench -p parquet --bench encoding --all-features -- --save-baseline master |
Odd. I'll play around some more, but on my laptop I'm not seeing the discrepancy. I'll try on some different hardware. |
I'll double check too |
On my workstation I'm seeing pretty consistent numbers for f32, but the new f64 encode is around 1-2% slower and f64 decode is pretty consistently 5% faster. I wonder if the benchmarks are just really sensitive to architecture. |
On my workstation I'm seeing pretty consistent numbers for f32, but the new f64 encode is around 1-2% slower and f64 decode is pretty consistently 5% faster. I wonder if the benchmarks are just really sensitive to architecture. Given the numbers are reported in usec and we didn't really change anything related to f32 decoding, I would tend to agree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this looks good to me. Thank you @etseidl
|
||
let stride = self.encoded_bytes.len() / type_size; | ||
match type_size { | ||
2 => join_streams_const::<2>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is fine to keep
for (i, bi) in buffer.iter_mut().enumerate().take(num_values) { | ||
// Get a view into the data, without also copying the bytes | ||
let data = bytes_with_data.slice(i * type_size..(i + 1) * type_size); | ||
// TODO: perhaps add a `set_from_bytes` method to `DataType` to avoid downcasting |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe @tustvold or @XiangpengHao has some suggestion on how to avoid this downcasting
I've done some more performance tweaking. By reworking
The new code replaces the current put logic values.iter().for_each(|x| {
let bytes = x.as_bytes();
...
self.buffer.extend(bytes)
}); with a parameterized function fn put_fixed<T: DataType, const TYPE_SIZE: usize>(dst: &mut [u8], values: &[T::T]) {
let mut idx = 0;
values.iter().for_each(|x| {
let bytes = x.as_bytes();
...
for i in 0..TYPE_SIZE {
dst[idx + i] = bytes[i]
}
idx += TYPE_SIZE;
});
} for I'll push the new code once I have a roundtrip test to make sure it's working correctly. I also want to benchmark on a faster machine. In a subsequent PR I think I'll try tackling a more cache friendly transpose for Edit: changing the loop to fn put_fixed<T: DataType, const TYPE_SIZE: usize>(dst: &mut [u8], values: &[T::T]) {
let mut idx = 0;
values.iter().for_each(|x| {
let bytes = x.as_bytes();
...
dst[idx..(TYPE_SIZE + idx)].copy_from_slice(&bytes[..TYPE_SIZE]);
idx += TYPE_SIZE;
});
} |
I went ahead and optimized
|
// Now copy `values` into the buffer. For `type_width` <= 8 use a fixed size when | ||
// performing the copy as it is significantly faster. | ||
match self.type_width { | ||
2 => put_fixed::<T, 2>(out_buf, values), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So if type_width == 1, still put_variable
would be called?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still wondering why type_width
<= 8 by hande
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could throw 1 in, but FLBA(1) is kind of weird. An int would be much faster to deal with for a single byte. I suppose someone might be tempted to use it for a single (ASCII) character field...UTF8 would need multiple bytes anyway.
Still wondering why type_width <= 8 by hande
The reason for the special handling for 2-8 is shown by the benchmarks...those numbers are basically the current code vs using put_variable
exclusively. For Float16
using put_fixed
is more than 2X faster. The speed advantage pretty much goes away at type_width == 8
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is ok, even if it is unlikely that 5 byte fixed length byte arrays are an important usecase
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// Now copy `values` into the buffer. For `type_width` <= 8 use a fixed size when | ||
// performing the copy as it is significantly faster. | ||
match self.type_width { | ||
2 => put_fixed::<T, 2>(out_buf, values), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is ok, even if it is unlikely that 5 byte fixed length byte arrays are an important usecase
🚀 |
Which issue does this PR close?
Closes #6048.
Rationale for this change
BYTE_STREAM_SPLIT
encoding was recently expanded to include all fixed-width primitive types (primarily to support theFloat16
logical type, but it has been found to be beneficial for integer types as well).What changes are included in this PR?
The biggest change is adding the
type_length
from the Parquet schema to the encoder and decoder interface. This is necessary to handleFIXED_LEN_BYTE_ARRAY
data.Are there any user-facing changes?
Adds new data types to an existing encoding. May require additional documentation.