Commit 7bc6214

fix: Properly zip struct validities (#18886)
1 parent e9d835d commit 7bc6214

File tree: 4 files changed, +265 -69 lines changed
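For orientation, the user-facing scenario this commit addresses can be reproduced with a small sketch modelled on the new test test_struct_chunked_zip_18119 added below: two struct columns and a boolean mask whose chunk boundaries do not line up are combined through when/then/otherwise. The column names and chunk sizes simply mirror that test; they are illustrative, not part of the commit itself.

import polars as pl
from polars.testing import assert_frame_equal

dtype = pl.Struct({"x": pl.Null})

# Struct columns and a boolean mask built from differently sized chunks,
# so their chunk boundaries do not line up.
a = pl.concat([pl.Series("a", [None] * n, dtype).to_frame() for n in (2, 2, 1)])
b = pl.concat([pl.Series("b", [None] * n, dtype).to_frame() for n in (4, 1)])
mask = pl.concat([pl.Series("f", [None] * n, pl.Boolean).to_frame() for n in (3, 2)])

df = pl.concat([a, b, mask], how="horizontal")

# Zipping the struct columns must yield 5 null rows with a correct validity mask.
assert_frame_equal(
    df.select(pl.when(pl.col.f).then(pl.col.a).otherwise(pl.col.b)),
    pl.DataFrame([pl.Series("a", [None] * 5, dtype)]),
)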

crates/polars-arrow/src/array/mod.rs (+1)

@@ -178,6 +178,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static {
     /// The caller must ensure that `offset + length <= self.len()`
     #[must_use]
     unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Box<dyn Array> {
+        debug_assert!(offset + length <= self.len());
         let mut new = self.to_boxed();
         new.slice_unchecked(offset, length);
         new

crates/polars-core/src/chunked_array/mod.rs (+2)

@@ -768,6 +768,8 @@ where
             })
             .collect();
 
+        debug_assert_eq!(offset, array.len());
+
         // SAFETY: We just slice the original chunks, their type will not change.
         unsafe {
            Self::from_chunks_and_dtype(self.name().clone(), chunks, self.dtype().clone())
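The debug_assert_eq! added above checks that the per-chunk slicing loop consumed exactly array.len() elements. As a rough user-level sketch (illustrative values only; the exact internal entry point is not shown in this hunk), the kind of multi-chunk slicing that relies on this per-chunk offset bookkeeping looks like:

import polars as pl

# Two chunks: [1, 2] and [3, 4, 5].
s = pl.concat([pl.Series("a", [1, 2]), pl.Series("a", [3, 4, 5])], rechunk=False)
assert s.n_chunks() == 2

# Slicing across the chunk boundary depends on correct per-chunk offsets.
assert s.slice(1, 3).to_list() == [2, 3, 4]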

crates/polars-core/src/chunked_array/ops/zip.rs (+233 -69)

@@ -1,6 +1,6 @@
 use std::borrow::Cow;
 
-use arrow::bitmap::Bitmap;
+use arrow::bitmap::{Bitmap, MutableBitmap};
 use arrow::compute::utils::{combine_validities_and, combine_validities_and_not};
 use polars_compute::if_then_else::{if_then_else_validity, IfThenElseKernel};
 
@@ -216,7 +216,10 @@ impl ChunkZip<StructType> for StructChunked {
         mask: &BooleanChunked,
         other: &ChunkedArray<StructType>,
     ) -> PolarsResult<ChunkedArray<StructType>> {
-        let length = self.length.max(mask.length).max(other.length);
+        let min_length = self.length.min(mask.length).min(other.length);
+        let max_length = self.length.max(mask.length).max(other.length);
+
+        let length = if min_length == 0 { 0 } else { max_length };
 
         debug_assert!(self.length == 1 || self.length == length);
         debug_assert!(mask.length == 1 || mask.length == length);
@@ -227,6 +230,26 @@
         let mut if_true: Cow<ChunkedArray<StructType>> = Cow::Borrowed(self);
         let mut if_false: Cow<ChunkedArray<StructType>> = Cow::Borrowed(other);
 
+        // Special case. In this case, we know what to do.
+        // @TODO: Optimization. If all mask values are the same, select one of the two.
+        if mask.length == 1 {
+            // pl.when(None) <=> pl.when(False)
+            let is_true = mask.get(0).unwrap_or(false);
+            return Ok(if is_true && self.length == 1 {
+                self.new_from_index(0, length)
+            } else if is_true {
+                self.clone()
+            } else if other.length == 1 {
+                let mut s = other.new_from_index(0, length);
+                s.rename(self.name().clone());
+                s
+            } else {
+                let mut s = other.clone();
+                s.rename(self.name().clone());
+                s
+            });
+        }
+
         // align_chunks_ternary can only align chunks if:
         // - Each chunkedarray only has 1 chunk
         // - Each chunkedarray has an equal length (i.e. is broadcasted)
@@ -235,21 +258,6 @@
         let needs_broadcast =
             if_true.chunks().len() > 1 || if_false.chunks().len() > 1 || mask.chunks().len() > 1;
         if needs_broadcast && length > 1 {
-            // Special case. In this case, we know what to do.
-            if mask.length == 1 {
-                // pl.when(None) <=> pl.when(False)
-                let is_true = mask.get(0).unwrap_or(false);
-                return Ok(if is_true && self.length == 1 {
-                    self.new_from_index(0, length)
-                } else if is_true {
-                    self.clone()
-                } else if other.length == 1 {
-                    other.new_from_index(0, length)
-                } else {
-                    other.clone()
-                });
-            }
-
             if self.length == 1 {
                 let broadcasted = self.new_from_index(0, length);
                 if_true = Cow::Owned(broadcasted);
@@ -288,70 +296,226 @@
 
         let mut out = StructChunked::from_series(self.name().clone(), fields.iter())?;
 
-        // Zip the validities.
-        if (l.null_count + r.null_count) > 0 {
-            let validities = l
-                .chunks()
-                .iter()
-                .zip(r.chunks())
-                .map(|(l, r)| (l.validity(), r.validity()));
-
-            fn broadcast(v: Option<&Bitmap>, arr: &ArrayRef) -> Bitmap {
-                if v.unwrap().get(0).unwrap() {
-                    Bitmap::new_with_value(true, arr.len())
-                } else {
-                    Bitmap::new_zeroed(arr.len())
+        fn rechunk_bitmaps(
+            total_length: usize,
+            iter: impl Iterator<Item = (usize, Option<Bitmap>)>,
+        ) -> Option<Bitmap> {
+            let mut rechunked_length = 0;
+            let mut rechunked_validity = None;
+            for (chunk_length, validity) in iter {
+                if let Some(validity) = validity {
+                    if validity.unset_bits() > 0 {
+                        rechunked_validity
+                            .get_or_insert_with(|| {
+                                let mut bm = MutableBitmap::with_capacity(total_length);
+                                bm.extend_constant(rechunked_length, true);
+                                bm
+                            })
+                            .extend_from_bitmap(&validity);
+                    }
                 }
+
+                rechunked_length += chunk_length;
             }
 
-            // # SAFETY
-            // We don't modify the length and update the null count.
-            unsafe {
-                for ((arr, (lv, rv)), mask) in out
-                    .chunks_mut()
-                    .iter_mut()
-                    .zip(validities)
-                    .zip(mask.downcast_iter())
-                {
-                    // TODO! we can optimize this and use a kernel that is able to broadcast wo/ allocating.
-                    let (lv, rv) = match (lv.map(|b| b.len()), rv.map(|b| b.len())) {
-                        (Some(1), Some(1)) if arr.len() != 1 => {
-                            let lv = broadcast(lv, arr);
-                            let rv = broadcast(rv, arr);
-                            (Some(lv), Some(rv))
-                        },
-                        (Some(a), Some(b)) if a == b => (lv.cloned(), rv.cloned()),
-                        (Some(1), _) => {
-                            let lv = broadcast(lv, arr);
-                            (Some(lv), rv.cloned())
-                        },
-                        (_, Some(1)) => {
-                            let rv = broadcast(rv, arr);
-                            (lv.cloned(), Some(rv))
-                        },
-                        (None, Some(_)) | (Some(_), None) | (None, None) => {
-                            (lv.cloned(), rv.cloned())
-                        },
-                        (Some(a), Some(b)) => {
-                            polars_bail!(InvalidOperation: "got different sizes in 'zip' operation, got length: {a} and {b}")
-                        },
-                    };
+            if let Some(rechunked_validity) = rechunked_validity.as_mut() {
+                rechunked_validity.extend_constant(total_length - rechunked_validity.len(), true);
+            }
+
+            rechunked_validity.map(MutableBitmap::freeze)
+        }
 
-                    // broadcast mask
-                    let validity = if mask.len() != arr.len() && mask.len() == 1 {
-                        if mask.get(0).unwrap() {
-                            lv
+        // Zip the validities.
+        //
+        // We need to take two things into account:
+        // 1. The chunk lengths of `out` might not necessarily match `l`, `r` and `mask`.
+        // 2. `l` and `r` might still need to be broadcasted.
+        if (l.null_count + r.null_count) > 0 {
+            // Create one validity mask that spans the entirety of out.
+            let rechunked_validity = match (l.len(), r.len()) {
+                (1, 1) if length != 1 => match (l.null_count() == 0, r.null_count() == 0) {
+                    (true, true) => None,
+                    (true, false) => {
+                        if mask.chunks().len() == 1 {
+                            let m = mask.chunks()[0]
+                                .as_any()
+                                .downcast_ref::<BooleanArray>()
+                                .unwrap()
+                                .values();
+                            Some(!m)
                         } else {
-                            rv
+                            rechunk_bitmaps(
+                                length,
+                                mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))),
+                            )
                         }
+                    },
+                    (false, true) => {
+                        if mask.chunks().len() == 1 {
+                            let m = mask.chunks()[0]
+                                .as_any()
+                                .downcast_ref::<BooleanArray>()
+                                .unwrap()
+                                .values();
+                            Some(m.clone())
+                        } else {
+                            rechunk_bitmaps(
+                                length,
+                                mask.downcast_iter()
+                                    .map(|m| (m.len(), Some(m.values().clone()))),
+                            )
+                        }
+                    },
+                    (false, false) => Some(Bitmap::new_zeroed(length)),
+                },
+                (1, _) if length != 1 => {
+                    debug_assert!(r
+                        .chunk_lengths()
+                        .zip(mask.chunk_lengths())
+                        .all(|(r, m)| r == m));
+
+                    let combine = if l.null_count() == 0 {
+                        |r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or_not(r, m))
                     } else {
-                        if_then_else_validity(mask.values(), lv.as_ref(), rv.as_ref())
+                        |r: Option<&Bitmap>, m: &Bitmap| {
+                            Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and(r, m)))
+                        }
                     };
 
-                    *arr = arr.with_validity(validity);
+                    if r.chunks().len() == 1 {
+                        let r = r.chunks()[0].validity();
+                        let m = mask.chunks()[0]
+                            .as_any()
+                            .downcast_ref::<BooleanArray>()
+                            .unwrap()
+                            .values();
+
+                        let validity = combine(r, m);
+                        validity.and_then(|v| (v.unset_bits() > 0).then_some(v))
+                    } else {
+                        rechunk_bitmaps(
+                            length,
+                            r.chunks()
+                                .iter()
+                                .zip(mask.downcast_iter())
+                                .map(|(chunk, mask)| {
+                                    (mask.len(), combine(chunk.validity(), mask.values()))
+                                }),
+                        )
+                    }
+                },
+                (_, 1) if length != 1 => {
+                    debug_assert!(l
+                        .chunk_lengths()
+                        .zip(mask.chunk_lengths())
+                        .all(|(l, m)| l == m));
+
+                    let combine = if r.null_count() == 0 {
+                        |r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or(r, m))
+                    } else {
+                        |r: Option<&Bitmap>, m: &Bitmap| {
+                            Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and_not(r, m)))
+                        }
+                    };
+
+                    if l.chunks().len() == 1 {
+                        let l = l.chunks()[0].validity();
+                        let m = mask.chunks()[0]
+                            .as_any()
+                            .downcast_ref::<BooleanArray>()
+                            .unwrap()
+                            .values();
+
+                        let validity = combine(l, m);
+                        validity.and_then(|v| (v.unset_bits() > 0).then_some(v))
+                    } else {
+                        rechunk_bitmaps(
+                            length,
+                            l.chunks()
+                                .iter()
+                                .zip(mask.downcast_iter())
+                                .map(|(chunk, mask)| {
+                                    (mask.len(), combine(chunk.validity(), mask.values()))
+                                }),
+                        )
+                    }
+                },
+                (_, _) => {
+                    debug_assert!(l
+                        .chunk_lengths()
+                        .zip(r.chunk_lengths())
+                        .all(|(l, r)| l == r));
+                    debug_assert!(l
+                        .chunk_lengths()
+                        .zip(mask.chunk_lengths())
+                        .all(|(l, r)| l == r));
+
+                    let validities = l
+                        .chunks()
+                        .iter()
+                        .zip(r.chunks())
+                        .map(|(l, r)| (l.validity(), r.validity()));
+
+                    rechunk_bitmaps(
+                        length,
+                        validities
+                            .zip(mask.downcast_iter())
+                            .map(|((lv, rv), mask)| {
+                                (mask.len(), if_then_else_validity(mask.values(), lv, rv))
+                            }),
+                    )
+                },
+            };
+
+            // Apply the validity spreading over the chunks of out.
+            if let Some(mut rechunked_validity) = rechunked_validity {
+                assert_eq!(rechunked_validity.len(), out.len());
+
+                let num_chunks = out.chunks().len();
+                let null_count = rechunked_validity.unset_bits();
+
+                // SAFETY: We do not change the lengths of the chunks and we update the null_count
+                // afterwards.
+                let chunks = unsafe { out.chunks_mut() };
+
+                if num_chunks == 1 {
+                    chunks[0] = chunks[0].with_validity(Some(rechunked_validity));
+                } else {
+                    for chunk in chunks {
+                        let chunk_len = chunk.len();
+                        let chunk_validity;
+
+                        // SAFETY: We know that rechunked_validity.len() == out.len()
+                        (chunk_validity, rechunked_validity) =
+                            unsafe { rechunked_validity.split_at_unchecked(chunk_len) };
+                        *chunk = chunk.with_validity(
+                            (chunk_validity.unset_bits() > 0).then_some(chunk_validity),
+                        );
+                    }
+                }
+
+                out.null_count = null_count as IdxSize;
+            } else {
+                // SAFETY: We do not change the lengths of the chunks and we update the null_count
+                // afterwards.
+                let chunks = unsafe { out.chunks_mut() };
+
+                for chunk in chunks {
+                    *chunk = chunk.with_validity(None);
                 }
+
+                out.null_count = 0 as IdxSize;
             }
-            out.compute_len();
         }
+
+        if cfg!(debug_assertions) {
+            let start_length = out.len();
+            let start_null_count = out.null_count();
+
+            out.compute_len();
+
+            assert_eq!(start_length, out.len());
+            assert_eq!(start_null_count, out.null_count());
+        }
         Ok(out)
     }
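The special case hoisted to the top of zip_with handles a unit-length mask, including the pl.when(None) <=> pl.when(False) equivalence noted in its comment and the rename of the otherwise result to the then branch's name. A minimal Python sketch of the user-level behavior that logic encodes (values and column names are made up, and a literal False condition stands in for the broadcastable unit-length mask):

import polars as pl

df = pl.DataFrame(
    {
        "a": [{"x": 1}, {"x": 2}],
        "b": [{"x": 10}, {"x": 20}],
    }
)

# A scalar False condition (and, per the comment in the diff, a scalar null
# condition as well) selects every row from the `otherwise` branch.
out = df.select(pl.when(pl.lit(False)).then(pl.col("a")).otherwise(pl.col("b")))

# The struct values come from `b`, but the column keeps the `then` branch's name.
assert out.columns == ["a"]
assert out.to_dicts() == [{"a": {"x": 10}}, {"a": {"x": 20}}]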

py-polars/tests/unit/datatypes/test_struct.py (+29)

@@ -1019,3 +1019,32 @@ def test_struct_group_by_shift_18107() -> None:
             [{"lon": 60, "lat": 50}, {"lon": 70, "lat": 60}, None],
         ],
     }
+
+
+def test_struct_chunked_zip_18119() -> None:
+    dtype = pl.Struct({"x": pl.Null})
+
+    a_dfs = [pl.DataFrame([pl.Series("a", [None] * i, dtype)]) for i in range(5)]
+    b_dfs = [pl.DataFrame([pl.Series("b", [None] * i, dtype)]) for i in range(5)]
+    mask_dfs = [
+        pl.DataFrame([pl.Series("f", [None] * i, pl.Boolean)]) for i in range(5)
+    ]
+
+    a = pl.concat([a_dfs[2], a_dfs[2], a_dfs[1]])
+    b = pl.concat([b_dfs[4], b_dfs[1]])
+    mask = pl.concat([mask_dfs[3], mask_dfs[2]])
+
+    df = pl.concat([a, b, mask], how="horizontal")
+
+    assert_frame_equal(
+        df.select(pl.when(pl.col.f).then(pl.col.a).otherwise(pl.col.b)),
+        pl.DataFrame([pl.Series("a", [None] * 5, dtype)]),
+    )
+
+
+def test_struct_null_zip() -> None:
+    df = pl.Series("int", [], dtype=pl.Struct({"x": pl.Int64})).to_frame()
+    assert_frame_equal(
+        df.select(pl.when(pl.Series([True])).then(pl.col.int).otherwise(pl.col.int)),
+        pl.Series("int", [], dtype=pl.Struct({"x": pl.Int64})).to_frame(),
+    )
