1
1
use std:: borrow:: Cow ;
2
2
3
- use arrow:: bitmap:: Bitmap ;
3
+ use arrow:: bitmap:: { Bitmap , MutableBitmap } ;
4
4
use arrow:: compute:: utils:: { combine_validities_and, combine_validities_and_not} ;
5
5
use polars_compute:: if_then_else:: { if_then_else_validity, IfThenElseKernel } ;
6
6
@@ -216,7 +216,10 @@ impl ChunkZip<StructType> for StructChunked {
216
216
mask : & BooleanChunked ,
217
217
other : & ChunkedArray < StructType > ,
218
218
) -> PolarsResult < ChunkedArray < StructType > > {
219
- let length = self . length . max ( mask. length ) . max ( other. length ) ;
219
+ let min_length = self . length . min ( mask. length ) . min ( other. length ) ;
220
+ let max_length = self . length . max ( mask. length ) . max ( other. length ) ;
221
+
222
+ let length = if min_length == 0 { 0 } else { max_length } ;
220
223
221
224
debug_assert ! ( self . length == 1 || self . length == length) ;
222
225
debug_assert ! ( mask. length == 1 || mask. length == length) ;
@@ -227,6 +230,26 @@ impl ChunkZip<StructType> for StructChunked {
227
230
let mut if_true: Cow < ChunkedArray < StructType > > = Cow :: Borrowed ( self ) ;
228
231
let mut if_false: Cow < ChunkedArray < StructType > > = Cow :: Borrowed ( other) ;
229
232
233
+ // Special case. In this case, we know what to do.
234
+ // @TODO: Optimization. If all mask values are the same, select one of the two.
235
+ if mask. length == 1 {
236
+ // pl.when(None) <=> pl.when(False)
237
+ let is_true = mask. get ( 0 ) . unwrap_or ( false ) ;
238
+ return Ok ( if is_true && self . length == 1 {
239
+ self . new_from_index ( 0 , length)
240
+ } else if is_true {
241
+ self . clone ( )
242
+ } else if other. length == 1 {
243
+ let mut s = other. new_from_index ( 0 , length) ;
244
+ s. rename ( self . name ( ) . clone ( ) ) ;
245
+ s
246
+ } else {
247
+ let mut s = other. clone ( ) ;
248
+ s. rename ( self . name ( ) . clone ( ) ) ;
249
+ s
250
+ } ) ;
251
+ }
252
+
230
253
// align_chunks_ternary can only align chunks if:
231
254
// - Each chunkedarray only has 1 chunk
232
255
// - Each chunkedarray has an equal length (i.e. is broadcasted)
@@ -235,21 +258,6 @@ impl ChunkZip<StructType> for StructChunked {
235
258
let needs_broadcast =
236
259
if_true. chunks ( ) . len ( ) > 1 || if_false. chunks ( ) . len ( ) > 1 || mask. chunks ( ) . len ( ) > 1 ;
237
260
if needs_broadcast && length > 1 {
238
- // Special case. In this case, we know what to do.
239
- if mask. length == 1 {
240
- // pl.when(None) <=> pl.when(False)
241
- let is_true = mask. get ( 0 ) . unwrap_or ( false ) ;
242
- return Ok ( if is_true && self . length == 1 {
243
- self . new_from_index ( 0 , length)
244
- } else if is_true {
245
- self . clone ( )
246
- } else if other. length == 1 {
247
- other. new_from_index ( 0 , length)
248
- } else {
249
- other. clone ( )
250
- } ) ;
251
- }
252
-
253
261
if self . length == 1 {
254
262
let broadcasted = self . new_from_index ( 0 , length) ;
255
263
if_true = Cow :: Owned ( broadcasted) ;
@@ -288,70 +296,226 @@ impl ChunkZip<StructType> for StructChunked {
288
296
289
297
let mut out = StructChunked :: from_series ( self . name ( ) . clone ( ) , fields. iter ( ) ) ?;
290
298
291
- // Zip the validities.
292
- if ( l. null_count + r. null_count ) > 0 {
293
- let validities = l
294
- . chunks ( )
295
- . iter ( )
296
- . zip ( r. chunks ( ) )
297
- . map ( |( l, r) | ( l. validity ( ) , r. validity ( ) ) ) ;
298
-
299
- fn broadcast ( v : Option < & Bitmap > , arr : & ArrayRef ) -> Bitmap {
300
- if v. unwrap ( ) . get ( 0 ) . unwrap ( ) {
301
- Bitmap :: new_with_value ( true , arr. len ( ) )
302
- } else {
303
- Bitmap :: new_zeroed ( arr. len ( ) )
299
+ fn rechunk_bitmaps (
300
+ total_length : usize ,
301
+ iter : impl Iterator < Item = ( usize , Option < Bitmap > ) > ,
302
+ ) -> Option < Bitmap > {
303
+ let mut rechunked_length = 0 ;
304
+ let mut rechunked_validity = None ;
305
+ for ( chunk_length, validity) in iter {
306
+ if let Some ( validity) = validity {
307
+ if validity. unset_bits ( ) > 0 {
308
+ rechunked_validity
309
+ . get_or_insert_with ( || {
310
+ let mut bm = MutableBitmap :: with_capacity ( total_length) ;
311
+ bm. extend_constant ( rechunked_length, true ) ;
312
+ bm
313
+ } )
314
+ . extend_from_bitmap ( & validity) ;
315
+ }
304
316
}
317
+
318
+ rechunked_length += chunk_length;
305
319
}
306
320
307
- // # SAFETY
308
- // We don't modify the length and update the null count.
309
- unsafe {
310
- for ( ( arr, ( lv, rv) ) , mask) in out
311
- . chunks_mut ( )
312
- . iter_mut ( )
313
- . zip ( validities)
314
- . zip ( mask. downcast_iter ( ) )
315
- {
316
- // TODO! we can optimize this and use a kernel that is able to broadcast wo/ allocating.
317
- let ( lv, rv) = match ( lv. map ( |b| b. len ( ) ) , rv. map ( |b| b. len ( ) ) ) {
318
- ( Some ( 1 ) , Some ( 1 ) ) if arr. len ( ) != 1 => {
319
- let lv = broadcast ( lv, arr) ;
320
- let rv = broadcast ( rv, arr) ;
321
- ( Some ( lv) , Some ( rv) )
322
- } ,
323
- ( Some ( a) , Some ( b) ) if a == b => ( lv. cloned ( ) , rv. cloned ( ) ) ,
324
- ( Some ( 1 ) , _) => {
325
- let lv = broadcast ( lv, arr) ;
326
- ( Some ( lv) , rv. cloned ( ) )
327
- } ,
328
- ( _, Some ( 1 ) ) => {
329
- let rv = broadcast ( rv, arr) ;
330
- ( lv. cloned ( ) , Some ( rv) )
331
- } ,
332
- ( None , Some ( _) ) | ( Some ( _) , None ) | ( None , None ) => {
333
- ( lv. cloned ( ) , rv. cloned ( ) )
334
- } ,
335
- ( Some ( a) , Some ( b) ) => {
336
- polars_bail ! ( InvalidOperation : "got different sizes in 'zip' operation, got length: {a} and {b}" )
337
- } ,
338
- } ;
321
+ if let Some ( rechunked_validity) = rechunked_validity. as_mut ( ) {
322
+ rechunked_validity. extend_constant ( total_length - rechunked_validity. len ( ) , true ) ;
323
+ }
324
+
325
+ rechunked_validity. map ( MutableBitmap :: freeze)
326
+ }
339
327
340
- // broadcast mask
341
- let validity = if mask. len ( ) != arr. len ( ) && mask. len ( ) == 1 {
342
- if mask. get ( 0 ) . unwrap ( ) {
343
- lv
328
+ // Zip the validities.
329
+ //
330
+ // We need to take two things into account:
331
+ // 1. The chunk lengths of `out` might not necessarily match `l`, `r` and `mask`.
332
+ // 2. `l` and `r` might still need to be broadcasted.
333
+ if ( l. null_count + r. null_count ) > 0 {
334
+ // Create one validity mask that spans the entirety of out.
335
+ let rechunked_validity = match ( l. len ( ) , r. len ( ) ) {
336
+ ( 1 , 1 ) if length != 1 => match ( l. null_count ( ) == 0 , r. null_count ( ) == 0 ) {
337
+ ( true , true ) => None ,
338
+ ( true , false ) => {
339
+ if mask. chunks ( ) . len ( ) == 1 {
340
+ let m = mask. chunks ( ) [ 0 ]
341
+ . as_any ( )
342
+ . downcast_ref :: < BooleanArray > ( )
343
+ . unwrap ( )
344
+ . values ( ) ;
345
+ Some ( !m)
344
346
} else {
345
- rv
347
+ rechunk_bitmaps (
348
+ length,
349
+ mask. downcast_iter ( ) . map ( |m| ( m. len ( ) , Some ( !m. values ( ) ) ) ) ,
350
+ )
346
351
}
352
+ } ,
353
+ ( false , true ) => {
354
+ if mask. chunks ( ) . len ( ) == 1 {
355
+ let m = mask. chunks ( ) [ 0 ]
356
+ . as_any ( )
357
+ . downcast_ref :: < BooleanArray > ( )
358
+ . unwrap ( )
359
+ . values ( ) ;
360
+ Some ( m. clone ( ) )
361
+ } else {
362
+ rechunk_bitmaps (
363
+ length,
364
+ mask. downcast_iter ( )
365
+ . map ( |m| ( m. len ( ) , Some ( m. values ( ) . clone ( ) ) ) ) ,
366
+ )
367
+ }
368
+ } ,
369
+ ( false , false ) => Some ( Bitmap :: new_zeroed ( length) ) ,
370
+ } ,
371
+ ( 1 , _) if length != 1 => {
372
+ debug_assert ! ( r
373
+ . chunk_lengths( )
374
+ . zip( mask. chunk_lengths( ) )
375
+ . all( |( r, m) | r == m) ) ;
376
+
377
+ let combine = if l. null_count ( ) == 0 {
378
+ |r : Option < & Bitmap > , m : & Bitmap | r. map ( |r| arrow:: bitmap:: or_not ( r, m) )
347
379
} else {
348
- if_then_else_validity ( mask. values ( ) , lv. as_ref ( ) , rv. as_ref ( ) )
380
+ |r : Option < & Bitmap > , m : & Bitmap | {
381
+ Some ( r. map_or_else ( || m. clone ( ) , |r| arrow:: bitmap:: and ( r, m) ) )
382
+ }
349
383
} ;
350
384
351
- * arr = arr. with_validity ( validity) ;
385
+ if r. chunks ( ) . len ( ) == 1 {
386
+ let r = r. chunks ( ) [ 0 ] . validity ( ) ;
387
+ let m = mask. chunks ( ) [ 0 ]
388
+ . as_any ( )
389
+ . downcast_ref :: < BooleanArray > ( )
390
+ . unwrap ( )
391
+ . values ( ) ;
392
+
393
+ let validity = combine ( r, m) ;
394
+ validity. and_then ( |v| ( v. unset_bits ( ) > 0 ) . then_some ( v) )
395
+ } else {
396
+ rechunk_bitmaps (
397
+ length,
398
+ r. chunks ( )
399
+ . iter ( )
400
+ . zip ( mask. downcast_iter ( ) )
401
+ . map ( |( chunk, mask) | {
402
+ ( mask. len ( ) , combine ( chunk. validity ( ) , mask. values ( ) ) )
403
+ } ) ,
404
+ )
405
+ }
406
+ } ,
407
+ ( _, 1 ) if length != 1 => {
408
+ debug_assert ! ( l
409
+ . chunk_lengths( )
410
+ . zip( mask. chunk_lengths( ) )
411
+ . all( |( l, m) | l == m) ) ;
412
+
413
+ let combine = if r. null_count ( ) == 0 {
414
+ |r : Option < & Bitmap > , m : & Bitmap | r. map ( |r| arrow:: bitmap:: or ( r, m) )
415
+ } else {
416
+ |r : Option < & Bitmap > , m : & Bitmap | {
417
+ Some ( r. map_or_else ( || m. clone ( ) , |r| arrow:: bitmap:: and_not ( r, m) ) )
418
+ }
419
+ } ;
420
+
421
+ if l. chunks ( ) . len ( ) == 1 {
422
+ let l = l. chunks ( ) [ 0 ] . validity ( ) ;
423
+ let m = mask. chunks ( ) [ 0 ]
424
+ . as_any ( )
425
+ . downcast_ref :: < BooleanArray > ( )
426
+ . unwrap ( )
427
+ . values ( ) ;
428
+
429
+ let validity = combine ( l, m) ;
430
+ validity. and_then ( |v| ( v. unset_bits ( ) > 0 ) . then_some ( v) )
431
+ } else {
432
+ rechunk_bitmaps (
433
+ length,
434
+ l. chunks ( )
435
+ . iter ( )
436
+ . zip ( mask. downcast_iter ( ) )
437
+ . map ( |( chunk, mask) | {
438
+ ( mask. len ( ) , combine ( chunk. validity ( ) , mask. values ( ) ) )
439
+ } ) ,
440
+ )
441
+ }
442
+ } ,
443
+ ( _, _) => {
444
+ debug_assert ! ( l
445
+ . chunk_lengths( )
446
+ . zip( r. chunk_lengths( ) )
447
+ . all( |( l, r) | l == r) ) ;
448
+ debug_assert ! ( l
449
+ . chunk_lengths( )
450
+ . zip( mask. chunk_lengths( ) )
451
+ . all( |( l, r) | l == r) ) ;
452
+
453
+ let validities = l
454
+ . chunks ( )
455
+ . iter ( )
456
+ . zip ( r. chunks ( ) )
457
+ . map ( |( l, r) | ( l. validity ( ) , r. validity ( ) ) ) ;
458
+
459
+ rechunk_bitmaps (
460
+ length,
461
+ validities
462
+ . zip ( mask. downcast_iter ( ) )
463
+ . map ( |( ( lv, rv) , mask) | {
464
+ ( mask. len ( ) , if_then_else_validity ( mask. values ( ) , lv, rv) )
465
+ } ) ,
466
+ )
467
+ } ,
468
+ } ;
469
+
470
+ // Apply the validity spreading over the chunks of out.
471
+ if let Some ( mut rechunked_validity) = rechunked_validity {
472
+ assert_eq ! ( rechunked_validity. len( ) , out. len( ) ) ;
473
+
474
+ let num_chunks = out. chunks ( ) . len ( ) ;
475
+ let null_count = rechunked_validity. unset_bits ( ) ;
476
+
477
+ // SAFETY: We do not change the lengths of the chunks and we update the null_count
478
+ // afterwards.
479
+ let chunks = unsafe { out. chunks_mut ( ) } ;
480
+
481
+ if num_chunks == 1 {
482
+ chunks[ 0 ] = chunks[ 0 ] . with_validity ( Some ( rechunked_validity) ) ;
483
+ } else {
484
+ for chunk in chunks {
485
+ let chunk_len = chunk. len ( ) ;
486
+ let chunk_validity;
487
+
488
+ // SAFETY: We know that rechunked_validity.len() == out.len()
489
+ ( chunk_validity, rechunked_validity) =
490
+ unsafe { rechunked_validity. split_at_unchecked ( chunk_len) } ;
491
+ * chunk = chunk. with_validity (
492
+ ( chunk_validity. unset_bits ( ) > 0 ) . then_some ( chunk_validity) ,
493
+ ) ;
494
+ }
495
+ }
496
+
497
+ out. null_count = null_count as IdxSize ;
498
+ } else {
499
+ // SAFETY: We do not change the lengths of the chunks and we update the null_count
500
+ // afterwards.
501
+ let chunks = unsafe { out. chunks_mut ( ) } ;
502
+
503
+ for chunk in chunks {
504
+ * chunk = chunk. with_validity ( None ) ;
352
505
}
506
+
507
+ out. null_count = 0 as IdxSize ;
353
508
}
509
+ }
510
+
511
+ if cfg ! ( debug_assertions) {
512
+ let start_length = out. len ( ) ;
513
+ let start_null_count = out. null_count ( ) ;
514
+
354
515
out. compute_len ( ) ;
516
+
517
+ assert_eq ! ( start_length, out. len( ) ) ;
518
+ assert_eq ! ( start_null_count, out. null_count( ) ) ;
355
519
}
356
520
Ok ( out)
357
521
}
0 commit comments