Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 117 additions & 13 deletions rust/arrow/src/compute/kernels/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ where
/// Returns the sum of values in the array.
///
/// Returns `None` if the array is empty or only contains null values.
#[cfg(not(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd")))]
pub fn sum<T>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
T: ArrowNumericType,
Expand All @@ -136,23 +137,126 @@ where
return None;
}

let mut n: T::Native = T::default_value();
let data = array.data();
let m = array.value_slice(0, data.len());
let data: &[T::Native] = array.value_slice(0, array.len());

if null_count == 0 {
// optimized path for arrays without null values
for item in m.iter().take(data.len()) {
n = n + *item;
match array.data().null_buffer() {
None => {
let sum = data.iter().fold(T::default_value(), |accumulator, value| {
accumulator + *value
});

Some(sum)
}
} else {
for (i, item) in m.iter().enumerate() {
if data.is_valid(i) {
n = n + *item;
}
Some(buffer) => {
let mut sum = T::default_value();
let data_chunks = data.chunks_exact(64);
let remainder = data_chunks.remainder();

let bit_chunks = buffer.bit_chunks(array.offset(), array.len());
&data_chunks
.zip(bit_chunks.iter())
.for_each(|(chunk, mask)| {
chunk.iter().enumerate().for_each(|(i, value)| {
if (mask & (1 << i)) != 0 {
sum = sum + *value;
}
});
});

let remainder_bits = bit_chunks.remainder_bits();

remainder.iter().enumerate().for_each(|(i, value)| {
if remainder_bits & (1 << i) != 0 {
sum = sum + *value;
}
});

Some(sum)
}
}
Some(n)
}

/// Returns the sum of values in the array.
///
/// Returns `None` if the array is empty or only contains null values.
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
T::Native: Add<Output = T::Native>,
{
let null_count = array.null_count();

if null_count == array.len() {
return None;
}

let data: &[T::Native] = array.value_slice(0, array.len());

let mut vector_sum = T::init(T::default_value());
let mut remainder_sum = T::default_value();

match array.data().null_buffer() {
None => {
let data_chunks = data.chunks_exact(64);
let remainder = data_chunks.remainder();

data_chunks.for_each(|chunk| {
chunk.chunks_exact(T::lanes()).for_each(|chunk| {
let chunk = T::load(&chunk);
vector_sum = vector_sum + chunk;
});
});

remainder.iter().for_each(|value| {
remainder_sum = remainder_sum + *value;
});
}
Some(buffer) => {
// process data in chunks of 64 elements since we also get 64 bits of validity information at a time
let data_chunks = data.chunks_exact(64);
let remainder = data_chunks.remainder();

let bit_chunks = buffer.bit_chunks(array.offset(), array.len());
let remainder_bits = bit_chunks.remainder_bits();

data_chunks.zip(bit_chunks).for_each(|(chunk, mut mask)| {
// split chunks further into slices corresponding to the vector length
// the compiler is able to unroll this inner loop and remove bounds checks
// since the outer chunk size (64) is always a multiple of the number of lanes
chunk.chunks_exact(T::lanes()).for_each(|chunk| {
let zero = T::init(T::default_value());
let vecmask = T::mask_from_u64(mask);
let chunk = T::load(&chunk);
let blended = T::mask_select(vecmask, chunk, zero);

vector_sum = vector_sum + blended;

mask = mask >> T::lanes();
});
});

remainder.iter().enumerate().for_each(|(i, value)| {
if remainder_bits & (1 << i) != 0 {
remainder_sum = remainder_sum + *value;
}
});
}
}

// calculate horizontal sum of accumulator by writing to a temporary
// this is probably faster than extracting individual lanes
// the compiler is free to optimize this to something faster
let tmp = &mut [T::default_value(); 64];
T::write(vector_sum, &mut tmp[0..T::lanes()]);

let mut total_sum = T::default_value();
tmp[0..T::lanes()]
.iter()
.for_each(|lane| total_sum = total_sum + *lane);

total_sum = total_sum + remainder_sum;

Some(total_sum)
}

#[cfg(test)]
Expand Down
Loading