diff --git a/crates/polars-core/src/chunked_array/builder/mod.rs b/crates/polars-core/src/chunked_array/builder/mod.rs index d0f4913b38c7..e88417e74daa 100644 --- a/crates/polars-core/src/chunked_array/builder/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/mod.rs @@ -239,7 +239,12 @@ mod test { builder.append_null(); let out = builder.finish(); - let out = out.explode(false).unwrap(); + let out = out + .explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) + .unwrap(); assert_eq!(out.len(), 7); assert_eq!(out.get(6).unwrap(), AnyValue::Null); } diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index 913ef3fc5c8a..fcbc9a6df75d 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -630,6 +630,27 @@ impl ListChunked { } } + pub fn has_empty_lists(&self) -> bool { + for arr in self.downcast_iter() { + if arr.is_empty() { + continue; + } + + if match arr.validity() { + None => arr.offsets().lengths().any(|l| l == 0), + Some(validity) => arr + .offsets() + .lengths() + .enumerate() + .any(|(i, l)| l == 0 && unsafe { validity.get_bit_unchecked(i) }), + } { + return true; + } + } + + false + } + pub fn has_masked_out_values(&self) -> bool { for arr in self.downcast_iter() { if arr.is_empty() { diff --git a/crates/polars-core/src/chunked_array/ops/explode.rs b/crates/polars-core/src/chunked_array/ops/explode.rs index 86508ed67c4d..8714cc0383d5 100644 --- a/crates/polars-core/src/chunked_array/ops/explode.rs +++ b/crates/polars-core/src/chunked_array/ops/explode.rs @@ -8,7 +8,7 @@ use crate::prelude::*; use crate::series::implementations::null::NullChunked; pub(crate) trait ExplodeByOffsets { - fn explode_by_offsets(&self, offsets: &[i64], skip_empty: bool) -> Series; + fn explode_by_offsets(&self, offsets: &[i64], options: ExplodeOptions) -> Series; } unsafe fn unset_nulls( @@ -34,7 +34,7 @@ impl ExplodeByOffsets for ChunkedArray where T: PolarsIntegerType, { - fn explode_by_offsets(&self, offsets: &[i64], skip_empty: bool) -> Series { + fn explode_by_offsets(&self, offsets: &[i64], options: ExplodeOptions) -> Series { debug_assert_eq!(self.chunks.len(), 1); let arr = self.downcast_iter().next().unwrap(); @@ -67,7 +67,7 @@ where for &o in &offsets[1..] { let o = o as usize; - if !skip_empty && o == last { + if options.empty_as_null && o == last { if start != last { #[cfg(debug_assertions)] new_values.extend_from_slice(&values[start..last]); @@ -114,7 +114,7 @@ where } else { for &o in &offsets[1..] { let o = o as usize; - if !skip_empty && o == last { + if options.empty_as_null && o == last { if start != last { unsafe { new_values.extend_from_slice(values.get_unchecked(start..last)) }; } @@ -150,31 +150,31 @@ where } impl ExplodeByOffsets for Float32Chunked { - fn explode_by_offsets(&self, offsets: &[i64], skip_empty: bool) -> Series { + fn explode_by_offsets(&self, offsets: &[i64], options: ExplodeOptions) -> Series { self.apply_as_ints(|s| { let ca = s.u32().unwrap(); - ca.explode_by_offsets(offsets, skip_empty) + ca.explode_by_offsets(offsets, options) }) } } impl ExplodeByOffsets for Float64Chunked { - fn explode_by_offsets(&self, offsets: &[i64], skip_empty: bool) -> Series { + fn explode_by_offsets(&self, offsets: &[i64], options: ExplodeOptions) -> Series { self.apply_as_ints(|s| { let ca = s.u64().unwrap(); - ca.explode_by_offsets(offsets, skip_empty) + ca.explode_by_offsets(offsets, options) }) } } impl ExplodeByOffsets for NullChunked { - fn explode_by_offsets(&self, offsets: &[i64], skip_empty: bool) -> Series { + fn explode_by_offsets(&self, offsets: &[i64], options: ExplodeOptions) -> Series { let mut last_offset = offsets[0]; let mut len = 0; for &offset in &offsets[1..] { // If offset == last_offset we have an empty list and a new row is inserted, // therefore we always increase at least 1. - len += std::cmp::max(offset - last_offset, i64::from(!skip_empty)) as usize; + len += std::cmp::max(offset - last_offset, i64::from(options.empty_as_null)) as usize; last_offset = offset; } NullChunked::new(self.name.clone(), len).into_series() @@ -182,7 +182,7 @@ impl ExplodeByOffsets for NullChunked { } impl ExplodeByOffsets for BooleanChunked { - fn explode_by_offsets(&self, offsets: &[i64], skip_empty: bool) -> Series { + fn explode_by_offsets(&self, offsets: &[i64], options: ExplodeOptions) -> Series { debug_assert_eq!(self.chunks.len(), 1); let arr = self.downcast_iter().next().unwrap(); @@ -193,7 +193,7 @@ impl ExplodeByOffsets for BooleanChunked { let mut last = start; for &o in &offsets[1..] { let o = o as usize; - if !skip_empty && o == last { + if options.empty_as_null && o == last { if start != last { let vals = arr.slice_typed(start, last - start); @@ -283,12 +283,18 @@ mod test { assert!(ca._can_fast_explode()); // normal explode - let exploded = ca.explode(false)?; + let exploded = ca.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out: Vec<_> = exploded.i32()?.into_no_null_iter().collect(); assert_eq!(out, &[1, 2, 3, 3, 1, 2]); // sliced explode - let exploded = ca.slice(0, 1).explode(false)?; + let exploded = ca.slice(0, 1).explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out: Vec<_> = exploded.i32()?.into_no_null_iter().collect(); assert_eq!(out, &[1, 2, 3, 3]); @@ -310,7 +316,10 @@ mod test { .unwrap(); let ca = builder.finish(); - let exploded = ca.explode(false)?; + let exploded = ca.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; assert_eq!( Vec::from(exploded.i32()?), &[Some(1), Some(2), None, Some(3)] @@ -335,7 +344,10 @@ mod test { .unwrap(); let ca = builder.finish(); - let exploded = ca.explode(false)?; + let exploded = ca.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; assert_eq!( Vec::from(exploded.i32()?), &[Some(1), None, Some(2), None, Some(3), Some(4)] @@ -381,7 +393,10 @@ mod test { .unwrap(); let ca = builder.finish(); - let exploded = ca.explode(false)?; + let exploded = ca.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; assert_eq!( Vec::from(exploded.str()?), &[Some("abc"), None, Some("de"), None, Some("fg"), None] @@ -406,7 +421,10 @@ mod test { .unwrap(); let ca = builder.finish(); - let exploded = ca.explode(false)?; + let exploded = ca.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; assert_eq!( Vec::from(exploded.bool()?), &[Some(true), None, Some(false), None, Some(true), Some(true)] diff --git a/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs index 429ca35d7b52..04dbe2f053b9 100644 --- a/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs +++ b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs @@ -9,7 +9,7 @@ impl ListChunked { values: ArrayRef, offsets: &[i64], offsets_buf: OffsetsBuffer, - skip_empty: bool, + options: ExplodeOptions, ) -> (Series, OffsetsBuffer) { // SAFETY: inner_dtype should be correct let values = unsafe { @@ -25,16 +25,16 @@ impl ListChunked { let mut values = match values.dtype() { DataType::Boolean => { let t = values.bool().unwrap(); - ExplodeByOffsets::explode_by_offsets(t, offsets, skip_empty).into_series() + ExplodeByOffsets::explode_by_offsets(t, offsets, options).into_series() }, DataType::Null => { let t = values.null().unwrap(); - ExplodeByOffsets::explode_by_offsets(t, offsets, skip_empty).into_series() + ExplodeByOffsets::explode_by_offsets(t, offsets, options).into_series() }, dtype => { with_match_physical_numeric_polars_type!(dtype, |$T| { let t: &ChunkedArray<$T> = values.as_ref().as_ref(); - ExplodeByOffsets::explode_by_offsets(t, offsets, skip_empty).into_series() + ExplodeByOffsets::explode_by_offsets(t, offsets, options).into_series() }) }, }; @@ -55,7 +55,10 @@ impl ChunkExplode for ListChunked { Ok(offsets) } - fn explode_and_offsets(&self, skip_empty: bool) -> PolarsResult<(Series, OffsetsBuffer)> { + fn explode_and_offsets( + &self, + options: ExplodeOptions, + ) -> PolarsResult<(Series, OffsetsBuffer)> { // A list array's memory layout is actually already 'exploded', so we can just take the // values array of the list. And we also return a slice of the offsets. This slice can be // used to find the old list layout or indexes to expand a DataFrame in the same manner as @@ -66,7 +69,10 @@ impl ChunkExplode for ListChunked { let offsets = listarr.offsets().as_slice(); let mut values = listarr.values().clone(); - let (mut s, offsets) = if ca._can_fast_explode() { + let (mut s, offsets) = if ca._can_fast_explode() + && (!options.keep_nulls || !ca.has_nulls()) + && (!options.empty_as_null || !ca.has_empty_lists()) + { // ensure that the value array is sliced // as a list only slices its offsets on a slice operation @@ -112,7 +118,7 @@ impl ChunkExplode for ListChunked { let inner_phys = self.inner_dtype().to_physical(); if inner_phys.is_primitive_numeric() || inner_phys.is_null() || inner_phys.is_bool() { - return Ok(self.explode_specialized(values, offsets, offsets_buf, skip_empty)); + return Ok(self.explode_specialized(values, offsets, offsets_buf, options)); } // Use gather let mut indices = @@ -127,7 +133,7 @@ impl ChunkExplode for ListChunked { let start = previous as IdxSize; let end = offset as IdxSize; - if !skip_empty && len == 0 { + if options.empty_as_null && len == 0 { indices.push_null(); } else { indices.extend_trusted_len_values(start..end); @@ -156,13 +162,13 @@ impl ChunkExplode for ListChunked { // SAFETY: we are within bounds if unsafe { validity.get_bit_unchecked(i) } { // explode expects null value if sublist is empty. - if !skip_empty && len == 0 { + if options.empty_as_null && len == 0 { indices.push_null(); } else { indices.extend_trusted_len_values(start..end); } current_offset += len; - } else { + } else if options.keep_nulls { indices.push_null(); } previous = offset; @@ -236,7 +242,31 @@ impl ChunkExplode for ArrayChunked { Ok(offsets) } - fn explode_and_offsets(&self, _skip_empty: bool) -> PolarsResult<(Series, OffsetsBuffer)> { + fn explode_and_offsets( + &self, + options: ExplodeOptions, + ) -> PolarsResult<(Series, OffsetsBuffer)> { + if self.width() == 0 { + let mut num_nulls = 0; + if options.empty_as_null { + num_nulls += self.len() - self.null_count(); + } + if options.keep_nulls { + num_nulls += self.null_count(); + } + let offsets = (0..num_nulls as i64 + 1).collect::>(); + // SAFETY: monotonically increasing + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; + let s = Column::new_scalar( + self.name().clone(), + Scalar::null(self.inner_dtype().clone()), + num_nulls, + ) + .take_materialized_series(); + + return Ok((s, offsets)); + } + let ca = self.rechunk(); let arr = ca.downcast_iter().next().unwrap(); // fast-path for non-null array. @@ -278,7 +308,7 @@ impl ChunkExplode for ArrayChunked { let end = start + width as IdxSize; indices.extend_trusted_len_values(start..end); current_offset += width as i64; - } else { + } else if options.keep_nulls { indices.push_null(); } offsets.push(current_offset); diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 005840c72fac..0a3664acd2a3 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -81,13 +81,24 @@ pub trait ChunkAnyValue { fn get_any_value(&self, index: usize) -> PolarsResult>; } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))] +pub struct ExplodeOptions { + pub empty_as_null: bool, + pub keep_nulls: bool, +} + /// Explode/flatten a List or String Series pub trait ChunkExplode { - fn explode(&self, skip_empty: bool) -> PolarsResult { - self.explode_and_offsets(skip_empty).map(|t| t.0) + fn explode(&self, options: ExplodeOptions) -> PolarsResult { + self.explode_and_offsets(options).map(|t| t.0) } fn offsets(&self) -> PolarsResult>; - fn explode_and_offsets(&self, skip_empty: bool) -> PolarsResult<(Series, OffsetsBuffer)>; + fn explode_and_offsets( + &self, + options: ExplodeOptions, + ) -> PolarsResult<(Series, OffsetsBuffer)>; } pub trait ChunkBytes { diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 9e2310d14ffe..d2b6a0e3c354 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -1250,9 +1250,9 @@ impl Column { } } - pub fn explode(&self, skip_empty: bool) -> PolarsResult { + pub fn explode(&self, options: ExplodeOptions) -> PolarsResult { self.as_materialized_series() - .explode(skip_empty) + .explode(options) .map(Column::from) } pub fn implode(&self) -> PolarsResult { diff --git a/crates/polars-core/src/frame/explode.rs b/crates/polars-core/src/frame/explode.rs index 8d7e7ae762be..622243a14781 100644 --- a/crates/polars-core/src/frame/explode.rs +++ b/crates/polars-core/src/frame/explode.rs @@ -11,9 +11,15 @@ use crate::series::IsSorted; fn get_exploded(series: &Series) -> PolarsResult<(Series, OffsetsBuffer)> { match series.dtype() { - DataType::List(_) => series.list().unwrap().explode_and_offsets(false), + DataType::List(_) => series.list().unwrap().explode_and_offsets(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }), #[cfg(feature = "dtype-array")] - DataType::Array(_, _) => series.array().unwrap().explode_and_offsets(false), + DataType::Array(_, _) => series.array().unwrap().explode_and_offsets(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }), _ => polars_bail!(opq = explode, series.dtype()), } } @@ -34,7 +40,10 @@ impl DataFrame { let mut df = self.clone(); if self.is_empty() { for s in &columns { - df.with_column(s.as_materialized_series().explode(false)?)?; + df.with_column(s.as_materialized_series().explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?)?; } return Ok(df); } diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 1e6069eec758..c223efd5a857 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -625,11 +625,11 @@ impl Series { } /// Explode a list Series. This expands every item to a new row.. - pub fn explode(&self, skip_empty: bool) -> PolarsResult { + pub fn explode(&self, options: ExplodeOptions) -> PolarsResult { match self.dtype() { - DataType::List(_) => self.list().unwrap().explode(skip_empty), + DataType::List(_) => self.list().unwrap().explode(options), #[cfg(feature = "dtype-array")] - DataType::Array(_, _) => self.array().unwrap().explode(skip_empty), + DataType::Array(_, _) => self.array().unwrap().explode(options), _ => Ok(self.clone()), } } diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs index f59f333dfbaf..a2c1abf2c3c9 100644 --- a/crates/polars-core/src/series/ops/reshape.rs +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -222,7 +222,10 @@ impl Series { let s = self; let s = if let DataType::List(_) = s.dtype() { - Cow::Owned(s.explode(true)?) + Cow::Owned(s.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?) } else { Cow::Borrowed(s) }; @@ -333,7 +336,14 @@ mod test { let out = s.reshape_list(&dims)?; assert_eq!(out.len(), list_len); assert!(matches!(out.dtype(), DataType::List(_))); - assert_eq!(out.explode(false)?.len(), 4); + assert_eq!( + out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })? + .len(), + 4 + ); } Ok(()) diff --git a/crates/polars-expr/src/dispatch/array.rs b/crates/polars-expr/src/dispatch/array.rs index 3e4ac32967b9..d4923c8612d5 100644 --- a/crates/polars-expr/src/dispatch/array.rs +++ b/crates/polars-expr/src/dispatch/array.rs @@ -1,5 +1,5 @@ use polars_core::error::{PolarsResult, polars_bail, polars_ensure, polars_err}; -use polars_core::prelude::{Column, DataType, IntoColumn, SortOptions}; +use polars_core::prelude::{Column, DataType, ExplodeOptions, IntoColumn, SortOptions}; use polars_ops::prelude::array::ArrayNameSpace; #[cfg(feature = "array_to_struct")] use polars_plan::dsl::DslNameGenerator; @@ -39,7 +39,7 @@ pub fn function_expr_to_udf(func: IRArrayFunction) -> SpecialEq map_as_slice!(count_matches), Shift => map_as_slice!(shift), - Explode { skip_empty } => map_as_slice!(explode, skip_empty), + Explode(options) => map_as_slice!(explode, options), Slice(offset, length) => map!(slice, offset, length), #[cfg(feature = "array_to_struct")] ToStruct(ng) => map!(arr_to_struct, ng.clone()), @@ -200,8 +200,8 @@ pub(super) fn slice(s: &Column, offset: i64, length: i64) -> PolarsResult PolarsResult { - c[0].explode(skip_empty) +fn explode(c: &[Column], options: ExplodeOptions) -> PolarsResult { + c[0].explode(options) } fn concat_arr(args: &[Column]) -> PolarsResult { diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 968f23d719f8..013631e514ca 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -85,7 +85,12 @@ impl ApplyExpr { ca: ListChunked, ) -> PolarsResult> { let c = if self.is_scalar() { - let out = ca.explode(false).unwrap(); + let out = ca + .explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) + .unwrap(); // if the explode doesn't return the same len, it wasn't scalar. polars_ensure!(out.len() == ca.len(), InvalidOperation: "expected scalar for expr: {}, got {}", self.expr, &out); ac.update_groups = UpdateGroups::No; diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index 3f199915e543..6f6bc9e6ee6d 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -513,7 +513,12 @@ impl<'a> AggregationContext<'a> { AggState::AggregatedScalar(c) => (c, groups), AggState::LiteralScalar(c) => (c, groups), AggState::AggregatedList(c) => { - let flattened = c.explode(true).unwrap(); + let flattened = c + .explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + }) + .unwrap(); let groups = groups.into_owned(); // unroll the possible flattened state // say we have groups with overlapping windows: @@ -567,7 +572,13 @@ impl<'a> AggregationContext<'a> { } // We should not insert nulls, otherwise the offsets in the groups will not be correct. - Cow::Owned(c.explode(true).unwrap()) + Cow::Owned( + c.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + }) + .unwrap(), + ) }, AggState::AggregatedScalar(c) => Cow::Borrowed(c), AggState::LiteralScalar(c) => Cow::Borrowed(c), diff --git a/crates/polars-expr/src/expressions/sortby.rs b/crates/polars-expr/src/expressions/sortby.rs index d03d63b5b391..874136493ccf 100644 --- a/crates/polars-expr/src/expressions/sortby.rs +++ b/crates/polars-expr/src/expressions/sortby.rs @@ -412,7 +412,15 @@ impl PhysicalExpr for SortByExpr { // group_by operation - we must ensure that we are as well. if ordered_by_group_operation { let s = ac_in.aggregated(); - ac_in.with_values(s.explode(false).unwrap(), false, None)?; + ac_in.with_values( + s.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) + .unwrap(), + false, + None, + )?; } ac_in.with_groups(groups.into_sliceable()); diff --git a/crates/polars-expr/src/expressions/ternary.rs b/crates/polars-expr/src/expressions/ternary.rs index 9177b0f0fe27..2dfd3a44de99 100644 --- a/crates/polars-expr/src/expressions/ternary.rs +++ b/crates/polars-expr/src/expressions/ternary.rs @@ -67,7 +67,10 @@ fn finish_as_iters<'a>( // Exploded list should be equal to groups length. list_vals_len == ac_truthy.groups.len() { - out = out.explode(false)? + out = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })? } ac_truthy.with_agg_state(AggState::AggregatedList(out)); diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index d24b8b332648..bc55e5769bd2 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -505,7 +505,10 @@ impl PhysicalExpr for WindowExpr { let out = if self.phys_function.is_scalar() { ac.get_values().clone() } else { - ac.aggregated().explode(false)? + ac.aggregated().explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })? }; Ok(out.into_column()) }, @@ -513,7 +516,10 @@ impl PhysicalExpr for WindowExpr { // TODO! // investigate if sorted arrays can be return directly let out_column = ac.aggregated(); - let flattened = out_column.explode(false)?; + let flattened = out_column.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; // we extend the lifetime as we must convince the compiler that ac lives // long enough. We drop `GrouBy` when we are done with `ac`. let ac = unsafe { diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 3d6c95f27de0..7649293de15f 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -634,11 +634,11 @@ fn create_physical_expr_inner( expr: node_to_expr(expression, expr_arena), })) }, - Explode { expr, skip_empty } => { + Explode { expr, options } => { let input = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; - let skip_empty = *skip_empty; + let options = *options; let function = SpecialEq::new(Arc::new( - move |c: &mut [polars_core::frame::column::Column]| c[0].explode(skip_empty), + move |c: &mut [polars_core::frame::column::Column]| c[0].explode(options), ) as Arc); let output_field = expr_arena diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index fd624b3dc027..e696038858c1 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -249,7 +249,10 @@ fn test_binary_agg_context_0() -> PolarsResult<()> { .unwrap(); let out = out.column("foo")?; - let out = out.explode(false)?; + let out = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out = out.str()?; assert_eq!( Vec::from(out), @@ -293,7 +296,10 @@ fn test_binary_agg_context_1() -> PolarsResult<()> { // [90, 90] // [7, 90] let out = out.column("vals")?; - let out = out.explode(false)?; + let out = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out = out.i32()?; assert_eq!( Vec::from(out), @@ -314,7 +320,10 @@ fn test_binary_agg_context_1() -> PolarsResult<()> { // [90, 90] // [90, 7] let out = out.column("vals")?; - let out = out.explode(false)?; + let out = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out = out.i32()?; assert_eq!( Vec::from(out), @@ -344,7 +353,10 @@ fn test_binary_agg_context_2() -> PolarsResult<()> { // 3 - [3, 4] = [0, -1] // 5 - [5, 6] = [0, -1] let out = out.column("vals")?; - let out = out.explode(false)?; + let out = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out = out.i32()?; assert_eq!( Vec::from(out), @@ -362,7 +374,10 @@ fn test_binary_agg_context_2() -> PolarsResult<()> { // [3, 4] - 3 = [0, 1] // [5, 6] - 5 = [0, 1] let out = out.column("vals")?; - let out = out.explode(false)?; + let out = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out = out.i32()?; assert_eq!( Vec::from(out), @@ -449,7 +464,10 @@ fn take_aggregations() -> PolarsResult<()> { .sort(["user"], Default::default()) .collect()?; let s = out.column("ordered")?; - let flat = s.explode(false)?; + let flat = s.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let flat = flat.str()?; let vals = flat.into_no_null_iter().collect::>(); assert_eq!(vals, ["a", "b", "c", "a", "a"]); diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index 7f17881599c3..c49e8d7b9a99 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -1039,7 +1039,14 @@ fn test_group_by_cum_sum() -> PolarsResult<()> { .collect()?; assert_eq!( - Vec::from(out.column("vals")?.explode(false)?.i32()?), + Vec::from( + out.column("vals")? + .explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true + })? + .i32()? + ), [1, 5, 11, 3, 12, 20] .iter() .copied() @@ -1301,7 +1308,10 @@ fn test_sort_by() -> PolarsResult<()> { .group_by_stable([col("b")]) .agg([col("a").sort_by([col("b"), col("c")], SortMultipleOptions::default())]) .collect()?; - let a = out.column("a")?.explode(false)?; + let a = out.column("a")?.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; assert_eq!( Vec::from(a.i32().unwrap()), &[Some(3), Some(1), Some(2), Some(5), Some(4)] @@ -1314,7 +1324,10 @@ fn test_sort_by() -> PolarsResult<()> { .agg([col("a").sort_by([col("b"), col("c")], SortMultipleOptions::default())]) .collect()?; - let a = out.column("a")?.explode(false)?; + let a = out.column("a")?.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; assert_eq!( Vec::from(a.i32().unwrap()), &[Some(3), Some(1), Some(2), Some(5), Some(4)] @@ -1691,7 +1704,10 @@ fn test_single_ranked_group() -> PolarsResult<()> { .over_with_options(Some([col("group")]), None, WindowMapping::Join)?]) .collect()?; - let out = out.column("value")?.explode(false)?; + let out = out.column("value")?.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out = out.f64()?; assert_eq!( Vec::from(out), @@ -1760,7 +1776,10 @@ fn test_is_in() -> PolarsResult<()> { )]) .collect()?; let out = out.column("cars").unwrap(); - let out = out.explode(false)?; + let out = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out = out.bool().unwrap(); assert_eq!( Vec::from(out), @@ -1778,7 +1797,10 @@ fn test_is_in() -> PolarsResult<()> { .collect()?; let out = out.column("cars").unwrap(); - let out = out.explode(false)?; + let out = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out = out.bool().unwrap(); assert_eq!( Vec::from(out), diff --git a/crates/polars-ops/src/chunked_array/list/min_max.rs b/crates/polars-ops/src/chunked_array/list/min_max.rs index 5d0ea37872ee..2dc90567edf7 100644 --- a/crates/polars-ops/src/chunked_array/list/min_max.rs +++ b/crates/polars-ops/src/chunked_array/list/min_max.rs @@ -96,7 +96,10 @@ pub(super) fn list_min_function(ca: &ListChunked) -> PolarsResult { let sc = s.min_reduce()?; Ok(sc.into_series(s.name().clone())) })? - .explode(false) + .explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) .unwrap() .into_series() .cast(dt), @@ -208,7 +211,10 @@ pub(super) fn list_max_function(ca: &ListChunked) -> PolarsResult { let sc = s.max_reduce()?; Ok(sc.into_series(s.name().clone())) })? - .explode(false) + .explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) .unwrap() .into_series() .cast(dt), diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 51a168254ef2..71d4269c66ce 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -489,7 +489,10 @@ pub trait ListNameSpaceImpl: AsList { list_ca.inner_dtype(), ) } else { - let s = list_ca.explode(false)?; + let s = list_ca.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; idx_ca .into_iter() .map(|opt_idx| { @@ -503,7 +506,10 @@ pub trait ListNameSpaceImpl: AsList { Ok(out.into_series()) }, (_, 1) => { - let idx_ca = idx_ca.explode(false)?; + let idx_ca = idx_ca.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; use DataType as D; match idx_ca.dtype() { diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs index 2edd6b19289b..0944851755ce 100644 --- a/crates/polars-ops/src/chunked_array/list/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -133,7 +133,10 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars .sum_reduce() .map(|sc| sc.into_series(PlSmallStr::EMPTY)) })? - .explode(false) + .explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) .unwrap() .into_series() .cast(dt)?, diff --git a/crates/polars-ops/src/chunked_array/strings/find_many.rs b/crates/polars-ops/src/chunked_array/strings/find_many.rs index 83652e91eebe..fbe3d5b8ff81 100644 --- a/crates/polars-ops/src/chunked_array/strings/find_many.rs +++ b/crates/polars-ops/src/chunked_array/strings/find_many.rs @@ -41,7 +41,10 @@ pub fn contains_any( return Ok(BooleanChunked::full_null(ca.name().clone(), ca.len())); } - let patterns = patterns.explode(true)?; + let patterns = patterns.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let patterns = patterns.str()?; let ac = build_ac(patterns, ascii_case_insensitive)?; @@ -88,9 +91,15 @@ pub fn replace_all( return Ok(StringChunked::full_null(ca.name().clone(), ca.len())); } - let patterns = patterns.explode(true)?; + let patterns = patterns.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let patterns = patterns.str()?; - let replace_with = replace_with.explode(true)?; + let replace_with = replace_with.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let replace_with = replace_with.str()?; let replace_with = if replace_with.len() == 1 && patterns.len() > 1 { @@ -165,7 +174,10 @@ pub fn extract_many( }, }, (_, 1) => { - let patterns = patterns.explode(true)?; + let patterns = patterns.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let patterns = patterns.str()?; let ac = build_ac(patterns, ascii_case_insensitive)?; let mut builder = @@ -256,7 +268,10 @@ pub fn find_many( }, }, (_, 1) => { - let patterns = patterns.explode(true)?; + let patterns = patterns.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let patterns = patterns.str()?; let ac = build_ac(patterns, ascii_case_insensitive)?; let mut builder = B::new(ca.name().clone(), ca.len(), ca.len() * 2, DataType::UInt32); diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index a538bb10651f..1b868680bd7e 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -238,7 +238,10 @@ where return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); } - let other = other.explode(true)?; + let other = other.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let other = other.as_ref().as_ref(); is_in_helper_ca(ca_in, other, nulls_equal) } else { @@ -253,7 +256,10 @@ where return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); } - let other = other.explode(true)?; + let other = other.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let other = other.as_ref().as_ref(); is_in_helper_ca(ca_in, other, nulls_equal) } else { @@ -312,7 +318,10 @@ fn is_in_binary( return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); } - let other = other.explode(true)?; + let other = other.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let other = other.binary()?; is_in_helper_ca(ca_in, other, nulls_equal) } else { @@ -327,7 +336,10 @@ fn is_in_binary( return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); } - let other = other.explode(true)?; + let other = other.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let other = other.binary()?; is_in_helper_ca(ca_in, other, nulls_equal) } else { @@ -380,7 +392,10 @@ fn is_in_boolean( return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); } - let other = other.explode(true)?; + let other = other.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let other = other.bool()?; is_in_boolean_broadcast(ca_in, other, nulls_equal) } else { @@ -395,7 +410,10 @@ fn is_in_boolean( return Ok(BooleanChunked::full_null(ca_in.name().clone(), ca_in.len())); } - let other = other.explode(true)?; + let other = other.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; let other = other.bool()?; is_in_boolean_broadcast(ca_in, other, nulls_equal) } else { @@ -464,7 +482,10 @@ fn is_in_null(s: &Series, other: &Series, nulls_equal: bool) -> PolarsResult PolarsResult PolarsResult nyi = "`replace` with a replacement pattern per row" ); - let old = old.explode(true)?; - let new = new.explode(true)?; + let old = old.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; + let new = new.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; if old.is_empty() { return Ok(s.clone()); @@ -83,8 +89,14 @@ pub fn replace_or_default( nyi = "`replace_strict` with a replacement pattern per row" ); - let old = old.explode(true)?; - let new = new.explode(true)?; + let old = old.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; + let new = new.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; polars_ensure!( default.len() == s.len() || default.len() == 1, @@ -136,8 +148,14 @@ pub fn replace_strict( nyi = "`replace_strict` with a replacement pattern per row" ); - let old = old.explode(true)?; - let new = new.explode(true)?; + let old = old.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; + let new = new.explode(ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + })?; if old.is_empty() { polars_ensure!( diff --git a/crates/polars-plan/dsl-schema-hashes.json b/crates/polars-plan/dsl-schema-hashes.json index 7ca8d309174f..18c57de8eb33 100644 --- a/crates/polars-plan/dsl-schema-hashes.json +++ b/crates/polars-plan/dsl-schema-hashes.json @@ -3,7 +3,7 @@ "AnonymousColumnsUdf": "5bbddd4f899afa592c318b20bb8d0bdfe2877fa5bf1a63d9cd0da908ac3aec0e", "AnyValue": "f1ae4795a8f61ca45cb81da244dc3b5a266f81b2090fab27fc50f6edf87a86ee", "ArrayDataTypeFunction": "c6089e74d6b54ea7576f21b0bf7d449d60f091243565d245188126f0cd7f1bf6", - "ArrayFunction": "acacf3b4189157c3898113e5d195b05619d7ae727734c4518e170edbf6611e1f", + "ArrayFunction": "b437b9e540cd4400da8a3a013000a7bbe4c48ee5de9e18e3da018362817b492f", "Array_of_PlPath": "539ecfb914d069d118ef07e335fa9ea72a5eff221a9679f577b6753727d30f40", "AsOfOptions": "f61410edcacd7b460cec03b8178870f62e61d37e5d0042c1ccb29543cc24dc08", "AsofStrategy": "777dd1236ad9111d4d0c5b537364eea2722a67f1771d1a49ee52869e15937830", @@ -50,7 +50,8 @@ "EWMOptions": "3997323cf1a48491ab48ed491cabf768954175970f83c0e7899490a58d310322", "Either_PythonObject_or_Schema_for_DataType": "6232a29ef51626d332177544fe80084dbc5451e45087aacafae633c93526ee6e", "EvalVariant": "6f3f2249f963d4b89339a93beace83e0be41310b4779af62ace5d4240013d7d8", - "Expr": "d387bccb6ddbcaa3346f8ac6735e684cf62623f3201cf4847a50e16a37fe10b6", + "ExplodeOptions": "46ef78ccb0ca3a84a96dc69c4bba22790e9adc50a2862a68fa8c58c793c660bf", + "Expr": "a3b713cacabb85744b30f1808ad47bdcb97f83445e6a517abb2e5b7f2cd4ab0f", "ExtraColumnsPolicy": "eb81efadce58eb148e658db4f2b5c1f38155d617431b81121043e9f9c21acd30", "Field": "dd95c2b6d7aa44004b900ef31fcf18e70f862d97488ef46c67b7c64c226b50d8", "FileScanDsl": "aec02dec7ace1d00b449f2f03fe5dc17b2d668cad483a74bc83ad5aee4b14981", diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index b3d5b608fd7d..b1a6fea9ceb8 100644 --- a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -191,11 +191,9 @@ impl ArrayNameSpace { .map_binary(FunctionExpr::ArrayExpr(ArrayFunction::Shift), n) } /// Returns a column with a separate row for every array element. - pub fn explode(self) -> Expr { + pub fn explode(self, options: ExplodeOptions) -> Expr { self.0 - .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Explode { - skip_empty: false, - })) + .map_unary(FunctionExpr::ArrayExpr(ArrayFunction::Explode(options))) } pub fn eval>(self, other: E, as_list: bool) -> Expr { diff --git a/crates/polars-plan/src/dsl/expr/mod.rs b/crates/polars-plan/src/dsl/expr/mod.rs index 57264c8c1a25..85241f889f89 100644 --- a/crates/polars-plan/src/dsl/expr/mod.rs +++ b/crates/polars-plan/src/dsl/expr/mod.rs @@ -137,7 +137,7 @@ pub enum Expr { }, Explode { input: Arc, - skip_empty: bool, + options: ExplodeOptions, }, Filter { input: Arc, @@ -340,8 +340,8 @@ impl Hash for Expr { sort_options.hash(state); }, Expr::Agg(input) => input.hash(state), - Expr::Explode { input, skip_empty } => { - skip_empty.hash(state); + Expr::Explode { input, options } => { + options.hash(state); input.hash(state) }, #[cfg(feature = "dynamic_group_by")] diff --git a/crates/polars-plan/src/dsl/format.rs b/crates/polars-plan/src/dsl/format.rs index 86969b581a80..5d3be43f231f 100644 --- a/crates/polars-plan/src/dsl/format.rs +++ b/crates/polars-plan/src/dsl/format.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::fmt::{self, Write}; use crate::prelude::*; @@ -46,12 +46,20 @@ impl fmt::Debug for Expr { Len => write!(f, "len()"), Explode { input: expr, - skip_empty: false, - } => write!(f, "{expr:?}.explode()"), - Explode { - input: expr, - skip_empty: true, - } => write!(f, "{expr:?}.explode(skip_empty)"), + options, + } => { + write!(f, "{expr:?}.explode(")?; + if !options.empty_as_null { + f.write_str("empty_as_null=false")?; + } + if !options.keep_nulls { + if options.empty_as_null { + f.write_str(", ")?; + } + f.write_str("keep_nulls=false")?; + } + f.write_char(')') + }, Alias(expr, name) => write!(f, "{expr:?}.alias(\"{name}\")"), Column(name) => write!(f, "col(\"{name}\")"), Literal(v) => write!(f, "{v:?}"), diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs index 02ff1c642ff7..3bbbc31f03a0 100644 --- a/crates/polars-plan/src/dsl/function_expr/array.rs +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -1,6 +1,6 @@ use std::fmt; -use polars_core::prelude::SortOptions; +use polars_core::prelude::{ExplodeOptions, SortOptions}; use super::FunctionExpr; @@ -37,9 +37,7 @@ pub enum ArrayFunction { #[cfg(feature = "array_count")] CountMatches, Shift, - Explode { - skip_empty: bool, - }, + Explode(ExplodeOptions), Concat, #[cfg(feature = "array_to_struct")] ToStruct(Option), diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index cb79ca2f15b8..25810a42833d 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -207,14 +207,17 @@ impl Expr { /// Alias for `explode`. pub fn flatten(self) -> Self { - self.explode() + self.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) } /// Explode the String/List column. - pub fn explode(self) -> Self { + pub fn explode(self, options: ExplodeOptions) -> Self { Expr::Explode { input: Arc::new(self), - skip_empty: false, + options, } } diff --git a/crates/polars-plan/src/plans/aexpr/builder.rs b/crates/polars-plan/src/plans/aexpr/builder.rs index 2125c2b2c717..7a08f744b8b9 100644 --- a/crates/polars-plan/src/plans/aexpr/builder.rs +++ b/crates/polars-plan/src/plans/aexpr/builder.rs @@ -1,5 +1,5 @@ use polars_core::chunked_array::cast::CastOptions; -use polars_core::prelude::{DataType, SortMultipleOptions, SortOptions}; +use polars_core::prelude::{DataType, ExplodeOptions, SortMultipleOptions, SortOptions}; use polars_core::scalar::Scalar; use polars_utils::IdxSize; use polars_utils::arena::{Arena, Node}; @@ -191,21 +191,11 @@ impl AExprBuilder { ) } - pub fn explode_skip_empty(self, arena: &mut Arena) -> Self { + pub fn explode(self, arena: &mut Arena, options: ExplodeOptions) -> Self { Self::new_from_aexpr( AExpr::Explode { expr: self.node(), - skip_empty: true, - }, - arena, - ) - } - - pub fn explode_null_empty(self, arena: &mut Arena) -> Self { - Self::new_from_aexpr( - AExpr::Explode { - expr: self.node(), - skip_empty: false, + options, }, arena, ) diff --git a/crates/polars-plan/src/plans/aexpr/equality.rs b/crates/polars-plan/src/plans/aexpr/equality.rs index bc564c820650..65a50bc14590 100644 --- a/crates/polars-plan/src/plans/aexpr/equality.rs +++ b/crates/polars-plan/src/plans/aexpr/equality.rs @@ -60,7 +60,7 @@ impl AExpr { // match to be exhaustive. #[rustfmt::skip] let is_equal = match self { - E::Explode { expr: _, skip_empty: l_skip_empty } => matches!(other, E::Explode { expr: _, skip_empty: r_skip_empty } if l_skip_empty == r_skip_empty), + E::Explode { expr: _, options: l_options } => matches!(other, E::Explode { expr: _, options: r_options } if l_options == r_options), E::Column(l_name) => matches!(other, E::Column(r_name) if l_name == r_name), E::Literal(l_lit) => matches!(other, E::Literal(r_lit) if l_lit == r_lit), E::BinaryExpr { left: _, op: l_op, right: _ } => matches!(other, E::BinaryExpr { left: _, op: r_op, right: _ } if l_op == r_op), diff --git a/crates/polars-plan/src/plans/aexpr/function_expr/array.rs b/crates/polars-plan/src/plans/aexpr/function_expr/array.rs index 495c6eae2aa9..d970a60a71fc 100644 --- a/crates/polars-plan/src/plans/aexpr/function_expr/array.rs +++ b/crates/polars-plan/src/plans/aexpr/function_expr/array.rs @@ -34,9 +34,7 @@ pub enum IRArrayFunction { #[cfg(feature = "array_count")] CountMatches, Shift, - Explode { - skip_empty: bool, - }, + Explode(ExplodeOptions), Concat, Slice(i64, i64), #[cfg(feature = "array_to_struct")] diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 7a8464c930bd..deeac3a8e515 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -178,7 +178,7 @@ pub enum AExpr { Element, Explode { expr: Node, - skip_empty: bool, + options: ExplodeOptions, }, Column(PlSmallStr), Literal(LiteralValue), diff --git a/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs b/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs index 5608ac3f6c6c..cc3477fcf933 100644 --- a/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs +++ b/crates/polars-plan/src/plans/aexpr/predicates/skip_batches.rs @@ -314,6 +314,8 @@ fn aexpr_to_skip_batch_predicate_rec( constant_evaluate(lv_node, arena, schema, 0), ) { (Some(col), Some(_)) => { + use polars_core::prelude::ExplodeOptions; + let dtype = schema.get(col)?; if !does_dtype_have_sufficient_order(dtype) { return None; @@ -328,7 +330,13 @@ fn aexpr_to_skip_batch_predicate_rec( let col = col.clone(); let lv_node = lv_node.into_aexpr_builder(); - let lv_node_exploded = lv_node.explode_skip_empty(arena); + let lv_node_exploded = lv_node.explode( + arena, + ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + }, + ); let lv_min = lv_node_exploded.min(arena); let lv_max = lv_node_exploded.max(arena); diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs index dea5787d1e7c..1337fa65c9a8 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs @@ -678,7 +678,7 @@ fn expand_expression_rec( } } }, - Expr::Explode { input, skip_empty } => { + Expr::Explode { input, options } => { _ = expand_single( input.as_ref(), ignored_selector_columns, @@ -687,7 +687,7 @@ fn expand_expression_rec( opt_flags, |e| Expr::Explode { input: Arc::new(e), - skip_empty: *skip_empty, + options: *options, }, )? }, diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs index cf4c6ac5fec5..27e298b6ef0d 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs @@ -122,9 +122,9 @@ pub(super) fn to_aexpr_impl( let (v, output_name) = match expr { Expr::Element => (AExpr::Element, PlSmallStr::EMPTY), - Expr::Explode { input, skip_empty } => { + Expr::Explode { input, options } => { let (expr, output_name) = recurse_arc!(input)?; - (AExpr::Explode { expr, skip_empty }, output_name) + (AExpr::Explode { expr, options }, output_name) }, Expr::Alias(e, name) => return Ok((recurse_arc!(e)?.0, name)), Expr::Literal(lv) => { diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs index 6637a48a9154..bdfe412e1633 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs @@ -88,7 +88,7 @@ pub(super) fn convert_functions( #[cfg(feature = "array_count")] A::CountMatches => IA::CountMatches, A::Shift => IA::Shift, - A::Explode { skip_empty } => IA::Explode { skip_empty }, + A::Explode(options) => IA::Explode(options), A::Concat => IA::Concat, A::Slice(offset, length) => IA::Slice(offset, length), #[cfg(feature = "array_to_struct")] diff --git a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs index 89d46a01d671..6e8cbedcd667 100644 --- a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs +++ b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs @@ -7,9 +7,9 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { match expr { AExpr::Element => Expr::Element, - AExpr::Explode { expr, skip_empty } => Expr::Explode { + AExpr::Explode { expr, options } => Expr::Explode { input: Arc::new(node_to_expr(expr, expr_arena)), - skip_empty, + options, }, AExpr::Column(a) => Expr::Column(a), AExpr::Literal(s) => Expr::Literal(s), @@ -308,7 +308,7 @@ pub fn ir_function_to_dsl(input: Vec, function: IRFunctionExpr) -> Expr { IA::CountMatches => A::CountMatches, IA::Shift => A::Shift, IA::Slice(offset, length) => A::Slice(offset, length), - IA::Explode { skip_empty } => A::Explode { skip_empty }, + IA::Explode(options) => A::Explode(options), #[cfg(feature = "array_to_struct")] IA::ToStruct(ng) => A::ToStruct(ng), }) diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs index 9c8771cc36aa..58e70d014d93 100644 --- a/crates/polars-plan/src/plans/ir/format.rs +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -1,4 +1,4 @@ -use std::fmt::{self, Display, Formatter}; +use std::fmt::{self, Display, Formatter, Write}; use polars_core::frame::DataFrame; use polars_core::schema::Schema; @@ -377,13 +377,16 @@ impl Display for ExprIRDisplay<'_> { } }, Len => write!(f, "len()"), - Explode { expr, skip_empty } => { + Explode { expr, options } => { let expr = self.with_root(expr); - if *skip_empty { - write!(f, "{expr}.explode(skip_empty)") - } else { - write!(f, "{expr}.explode()") + write!(f, "{expr}.explode(")?; + match (options.empty_as_null, options.keep_nulls) { + (true, true) => {}, + (true, false) => f.write_str("keep_nulls=false")?, + (false, true) => f.write_str("empty_as_null=false")?, + (false, false) => f.write_str("empty_as_null=false, keep_nulls=false")?, } + f.write_char(')') }, Column(name) => write!(f, "col(\"{name}\")"), Literal(v) => write!(f, "{v:?}"), diff --git a/crates/polars-plan/src/plans/ir/tree_format.rs b/crates/polars-plan/src/plans/ir/tree_format.rs index ab004b6eda8b..b2b20dc44abf 100644 --- a/crates/polars-plan/src/plans/ir/tree_format.rs +++ b/crates/polars-plan/src/plans/ir/tree_format.rs @@ -1,4 +1,4 @@ -use std::fmt; +use std::fmt::{self, Write}; use polars_core::error::*; use polars_utils::format_list_truncated; @@ -25,14 +25,16 @@ impl fmt::Display for TreeFmtAExpr<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match self.0 { AExpr::Element => "element()", - AExpr::Explode { - expr: _, - skip_empty: false, - } => "explode", - AExpr::Explode { - expr: _, - skip_empty: true, - } => "explode(skip_empty)", + AExpr::Explode { expr: _, options } => { + f.write_str("explode(")?; + match (options.empty_as_null, options.keep_nulls) { + (true, true) => {}, + (true, false) => f.write_str("keep_nulls=false")?, + (false, true) => f.write_str("empty_as_null=false")?, + (false, false) => f.write_str("empty_as_null=false, keep_nulls=false")?, + } + return f.write_char(')'); + }, AExpr::Column(name) => return write!(f, "col({name})"), AExpr::Literal(lv) => return write!(f, "lit({lv:?})"), AExpr::BinaryExpr { op, .. } => return write!(f, "binary: {op}"), diff --git a/crates/polars-plan/src/plans/visitor/expr.rs b/crates/polars-plan/src/plans/visitor/expr.rs index bd78b266b05f..4f200955703e 100644 --- a/crates/polars-plan/src/plans/visitor/expr.rs +++ b/crates/polars-plan/src/plans/visitor/expr.rs @@ -72,7 +72,7 @@ impl TreeWalker for Expr { }), Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? }, Function { input, function } => Function { input: input.into_iter().map(f).collect::>()?, function }, - Explode { input, skip_empty } => Explode { input: am(input, f)?, skip_empty }, + Explode { input, options } => Explode { input: am(input, f)?, options }, Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? }, #[cfg(feature = "dynamic_group_by")] Rolling { function, index_column, period, offset, closed_window } => Rolling { function: am(function, &mut f)?, index_column: am(index_column, &mut f)?, period, offset, closed_window }, @@ -197,13 +197,13 @@ impl AExpr { ( Explode { expr: _, - skip_empty: l_skip_empty, + options: l_options, }, Explode { expr: _, - skip_empty: r_skip_empty, + options: r_options, }, - ) => l_skip_empty == r_skip_empty, + ) => l_options == r_options, ( SortBy { sort_options: l_sort_options, diff --git a/crates/polars-python/src/expr/array.rs b/crates/polars-python/src/expr/array.rs index ba686e56da2b..9d056dc97ae2 100644 --- a/crates/polars-python/src/expr/array.rs +++ b/crates/polars-python/src/expr/array.rs @@ -152,8 +152,15 @@ impl PyExpr { self.inner.clone().arr().shift(n.inner).into() } - fn arr_explode(&self) -> Self { - self.inner.clone().arr().explode().into() + fn arr_explode(&self, empty_as_null: bool, keep_nulls: bool) -> Self { + self.inner + .clone() + .arr() + .explode(ExplodeOptions { + empty_as_null, + keep_nulls, + }) + .into() } fn arr_eval(&self, expr: PyExpr, as_list: bool) -> Self { diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index 095d90a0b6be..059835d7c109 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -439,8 +439,14 @@ impl PyExpr { self.inner.clone().is_last_distinct().into() } - fn explode(&self) -> Self { - self.inner.clone().explode().into() + fn explode(&self, empty_as_null: bool, keep_nulls: bool) -> Self { + self.inner + .clone() + .explode(ExplodeOptions { + empty_as_null, + keep_nulls, + }) + .into() } fn gather_every(&self, n: usize, offset: usize) -> Self { diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 0c0eae1fd19d..f9f1f1ebb4c7 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -2,7 +2,8 @@ use std::ops::Sub; use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions}; use polars_core::prelude::{ - DataType, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail, polars_err, + DataType, ExplodeOptions, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail, + polars_err, }; use polars_lazy::dsl::Expr; use polars_ops::chunked_array::UnicodeForm; @@ -1527,7 +1528,12 @@ impl SQLFunctionVisitor<'_> { ArraySum => self.visit_unary(|e| e.list().sum()), ArrayToString => self.visit_arr_to_string(), ArrayUnique => self.visit_unary(|e| e.list().unique()), - Explode => self.visit_unary(|e| e.explode()), + Explode => self.visit_unary(|e| { + e.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) + }), // ---- // Column selection diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index b92766847095..d99db55c622e 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -526,7 +526,16 @@ fn test_arr_agg() { ), ( "SELECT unnest(ARRAY_AGG(DISTINCT a)) FROM df", - vec![col("a").unique_stable().implode().explode().alias("a")], + vec![ + col("a") + .unique_stable() + .implode() + .explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) + .alias("a"), + ], ), ( "SELECT ARRAY_AGG(a ORDER BY b LIMIT 2) FROM df", diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index 7a80e307e154..9d546422f225 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -610,7 +610,7 @@ fn lower_exprs_with_ctx( AExpr::Explode { expr: inner, - skip_empty, + options, } => { // While explode is streamable, it is not elementwise, so we // have to transform it to a select node. @@ -618,7 +618,7 @@ fn lower_exprs_with_ctx( let exploded_name = unique_column_name(); let trans_inner = ctx.expr_arena.add(AExpr::Explode { expr: trans_exprs[0], - skip_empty, + options, }); let explode_expr = ExprIR::new(trans_inner, OutputName::Alias(exploded_name.clone())); @@ -1044,6 +1044,8 @@ fn lower_exprs_with_ctx( options: _, } if is_scalar_ae(inner_exprs[1].node(), ctx.expr_arena) => { // Translate left and right side separately (they could have different lengths). + + use polars_core::prelude::ExplodeOptions; let left_on_name = unique_column_name(); let right_on_name = unique_column_name(); let (trans_input_left, trans_expr_left) = @@ -1051,10 +1053,15 @@ fn lower_exprs_with_ctx( let right_expr_exploded_node = match ctx.expr_arena.get(inner_exprs[1].node()) { // expr.implode().explode() ~= expr (and avoids rechunking) AExpr::Agg(IRAggExpr::Implode(n)) => *n, - _ => ctx.expr_arena.add(AExpr::Explode { - expr: inner_exprs[1].node(), - skip_empty: true, - }), + _ => AExprBuilder::new_from_node(inner_exprs[1].node()) + .explode( + ctx.expr_arena, + ExplodeOptions { + empty_as_null: false, + keep_nulls: true, + }, + ) + .node(), }; let (trans_input_right, trans_expr_right) = lower_exprs_with_ctx(input, &[right_expr_exploded_node], ctx)?; diff --git a/crates/polars-testing/src/asserts/utils.rs b/crates/polars-testing/src/asserts/utils.rs index 82bc9007403b..f90728930f06 100644 --- a/crates/polars-testing/src/asserts/utils.rs +++ b/crates/polars-testing/src/asserts/utils.rs @@ -446,8 +446,14 @@ fn assert_series_nested_values_equal( let s2_series = Series::new("".into(), std::slice::from_ref(&s2)); match assert_series_values_equal( - &s1_series.explode(false)?, - &s2_series.explode(false)?, + &s1_series.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?, + &s2_series.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?, true, check_exact, check_dtypes, diff --git a/crates/polars/tests/it/core/list.rs b/crates/polars/tests/it/core/list.rs index 122454f336c3..e9f518eec065 100644 --- a/crates/polars/tests/it/core/list.rs +++ b/crates/polars/tests/it/core/list.rs @@ -10,7 +10,12 @@ fn test_to_list_logical() -> PolarsResult<()> { // check if dtype is maintained all the way to formatting assert!(s.contains("[2021-01-01, 2021-01-02, 2021-01-03]")); - let expl = out.explode(false).unwrap(); + let expl = out + .explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + }) + .unwrap(); assert_eq!(expl.dtype(), &DataType::Date); Ok(()) } diff --git a/crates/polars/tests/it/lazy/expressions/arity.rs b/crates/polars/tests/it/lazy/expressions/arity.rs index bedb6e0cd9a2..c02190d3a56c 100644 --- a/crates/polars/tests/it/lazy/expressions/arity.rs +++ b/crates/polars/tests/it/lazy/expressions/arity.rs @@ -358,10 +358,13 @@ fn test_binary_group_consistency() -> PolarsResult<()> { assert_eq!(out.dtype(), &DataType::List(Box::new(DataType::String))); assert_eq!( - out.explode(false)? - .str()? - .into_no_null_iter() - .collect::>(), + out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true + })? + .str()? + .into_no_null_iter() + .collect::>(), &["a", "b", "c", "d"] ); diff --git a/crates/polars/tests/it/lazy/expressions/slice.rs b/crates/polars/tests/it/lazy/expressions/slice.rs index 8dfb52b4a68d..03018d820bd3 100644 --- a/crates/polars/tests/it/lazy/expressions/slice.rs +++ b/crates/polars/tests/it/lazy/expressions/slice.rs @@ -17,7 +17,10 @@ fn test_slice_args() -> PolarsResult<()> { .agg([col("vals").slice(lit(0i64), (len() * lit(0.2)).cast(DataType::Int32))]) .collect()?; - let out = df.column("vals")?.explode(false)?; + let out = df.column("vals")?.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let out = out.i32().unwrap(); assert_eq!( out.into_no_null_iter().collect::>(), diff --git a/crates/polars/tests/it/lazy/expressions/window.rs b/crates/polars/tests/it/lazy/expressions/window.rs index 524b29e61406..83001ba8b13f 100644 --- a/crates/polars/tests/it/lazy/expressions/window.rs +++ b/crates/polars/tests/it/lazy/expressions/window.rs @@ -174,7 +174,10 @@ fn test_literal_window_fn() -> PolarsResult<()> { let out = out.column("foo")?; assert!(matches!(out.dtype(), DataType::List(_))); - let flat = out.explode(false)?; + let flat = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; let flat = flat.i32()?; assert_eq!( Vec::from(flat), diff --git a/crates/polars/tests/it/lazy/group_by.rs b/crates/polars/tests/it/lazy/group_by.rs index 31d33100615a..1605b5edb95d 100644 --- a/crates/polars/tests/it/lazy/group_by.rs +++ b/crates/polars/tests/it/lazy/group_by.rs @@ -122,7 +122,10 @@ fn test_group_by_agg_list_with_not_aggregated() -> PolarsResult<()> { .collect()?; let out = out.column("value")?; - let out = out.explode(false)?; + let out = out.explode(ExplodeOptions { + empty_as_null: true, + keep_nulls: true, + })?; assert_eq!( out, Column::new("value".into(), &[0, 2, 1, 3, 2, 2, 7, 2, 3, 1, 2, 1]) diff --git a/py-polars/src/polars/_plr.pyi b/py-polars/src/polars/_plr.pyi index f647e0f353cc..41b0c10b584a 100644 --- a/py-polars/src/polars/_plr.pyi +++ b/py-polars/src/polars/_plr.pyi @@ -1255,7 +1255,7 @@ class PyExpr: def approx_n_unique(self) -> PyExpr: ... def is_first_distinct(self) -> PyExpr: ... def is_last_distinct(self) -> PyExpr: ... - def explode(self) -> PyExpr: ... + def explode(self, *, empty_as_null: bool, keep_nulls: bool) -> PyExpr: ... def gather_every(self, n: int, offset: int) -> PyExpr: ... def slice(self, offset: PyExpr, length: PyExpr) -> PyExpr: ... def append(self, other: PyExpr, upcast: bool) -> PyExpr: ... @@ -1427,7 +1427,7 @@ class PyExpr: ) -> PyExpr: ... def arr_tail(self, n: PyExpr, as_array: bool) -> PyExpr: ... def arr_shift(self, n: PyExpr) -> PyExpr: ... - def arr_explode(self) -> PyExpr: ... + def arr_explode(self, *, empty_as_null: bool, keep_nulls: bool) -> PyExpr: ... def arr_eval(self, expr: PyExpr, *, as_list: bool) -> PyExpr: ... def arr_agg(self, expr: PyExpr) -> PyExpr: ... diff --git a/py-polars/src/polars/expr/array.py b/py-polars/src/polars/expr/array.py index 1c83beb95b59..e75aa9d299dc 100644 --- a/py-polars/src/polars/expr/array.py +++ b/py-polars/src/polars/expr/array.py @@ -759,10 +759,17 @@ def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Expr: separator_pyexpr = parse_into_expression(separator, str_as_lit=True) return wrap_expr(self._pyexpr.arr_join(separator_pyexpr, ignore_nulls)) - def explode(self) -> Expr: + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Expr: """ Returns a column with a separate row for every array element. + Parameters + ---------- + empty_as_null + Explode an empty array into a `null`. + keep_nulls + Explode a `null` array into a `null`. + Returns ------- Expr @@ -788,7 +795,9 @@ def explode(self) -> Expr: │ 6 │ └─────┘ """ - return wrap_expr(self._pyexpr.arr_explode()) + return wrap_expr( + self._pyexpr.arr_explode(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + ) def contains(self, item: IntoExpr, *, nulls_equal: bool = True) -> Expr: """ diff --git a/py-polars/src/polars/expr/expr.py b/py-polars/src/polars/expr/expr.py index 2e21c230219b..8711f54f29ff 100644 --- a/py-polars/src/polars/expr/expr.py +++ b/py-polars/src/polars/expr/expr.py @@ -4936,14 +4936,21 @@ def flatten(self) -> Expr: │ b ┆ [2, 3, 4] │ └───────┴───────────┘ """ - return wrap_expr(self._pyexpr.explode()) + return self.explode(empty_as_null=True, keep_nulls=True) - def explode(self) -> Expr: + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Expr: """ Explode a list expression. This means that every item is expanded to a new row. + Parameters + ---------- + empty_as_null + Explode an empty list/array into a `null`. + keep_nulls + Explode a `null` list/array into a `null`. + Returns ------- Expr @@ -4977,7 +4984,9 @@ def explode(self) -> Expr: │ 4 │ └────────┘ """ - return wrap_expr(self._pyexpr.explode()) + return wrap_expr( + self._pyexpr.explode(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + ) def implode(self) -> Expr: """ diff --git a/py-polars/src/polars/expr/list.py b/py-polars/src/polars/expr/list.py index fc30f3690cd4..f6c129eb7721 100644 --- a/py-polars/src/polars/expr/list.py +++ b/py-polars/src/polars/expr/list.py @@ -1059,10 +1059,17 @@ def tail(self, n: int | str | Expr = 5) -> Expr: n_pyexpr = parse_into_expression(n) return wrap_expr(self._pyexpr.list_tail(n_pyexpr)) - def explode(self) -> Expr: + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Expr: """ Returns a column with a separate row for every list element. + Parameters + ---------- + empty_as_null + Explode an empty list into a `null`. + keep_nulls + Explode a `null` list into a `null`. + Returns ------- Expr @@ -1090,7 +1097,9 @@ def explode(self) -> Expr: │ 6 │ └─────┘ """ - return wrap_expr(self._pyexpr.explode()) + return wrap_expr( + self._pyexpr.explode(empty_as_null=empty_as_null, keep_nulls=keep_nulls) + ) def count_matches(self, element: IntoExpr) -> Expr: """ diff --git a/py-polars/src/polars/series/array.py b/py-polars/src/polars/series/array.py index 00383a4342d3..12ba164a7b3c 100644 --- a/py-polars/src/polars/series/array.py +++ b/py-polars/src/polars/series/array.py @@ -605,10 +605,17 @@ def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Serie """ - def explode(self) -> Series: + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Series: """ Returns a column with a separate row for every array element. + Parameters + ---------- + empty_as_null + Explode an empty array into a `null`. + keep_nulls + Explode a `null` array into a `null`. + Returns ------- Series diff --git a/py-polars/src/polars/series/list.py b/py-polars/src/polars/series/list.py index 425fffe9c370..5e929ff87b3a 100644 --- a/py-polars/src/polars/series/list.py +++ b/py-polars/src/polars/series/list.py @@ -817,10 +817,17 @@ def tail(self, n: int | Expr = 5) -> Series: ] """ - def explode(self) -> Series: + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Series: """ Returns a column with a separate row for every list element. + Parameters + ---------- + empty_as_null + Explode an empty list into a `null`. + keep_nulls + Explode a `null` list into a `null`. + Returns ------- Series diff --git a/py-polars/src/polars/series/series.py b/py-polars/src/polars/series/series.py index 33c380b5eff3..1b3b4ba1149d 100644 --- a/py-polars/src/polars/series/series.py +++ b/py-polars/src/polars/series/series.py @@ -4186,12 +4186,19 @@ def is_duplicated(self) -> Series: ] """ - def explode(self) -> Series: + def explode(self, *, empty_as_null: bool = True, keep_nulls: bool = True) -> Series: """ Explode a list Series. This means that every item is expanded to a new row. + Parameters + ---------- + empty_as_null + Explode an empty list into a `null`. + keep_nulls + Explode a `null` list into a `null`. + Returns ------- Series diff --git a/py-polars/tests/unit/operations/test_explode.py b/py-polars/tests/unit/operations/test_explode.py index 93fa08f24033..c6011a0710a2 100644 --- a/py-polars/tests/unit/operations/test_explode.py +++ b/py-polars/tests/unit/operations/test_explode.py @@ -2,11 +2,13 @@ import pyarrow as pa import pytest +from hypothesis import given import polars as pl import polars.selectors as cs from polars.exceptions import ShapeError from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import series def test_explode_multiple() -> None: @@ -458,3 +460,134 @@ def test_explode_17648() -> None: def test_explode_struct_nulls() -> None: df = pl.DataFrame({"A": [[{"B": 1}], [None], []]}) assert df.explode("A").to_dict(as_series=False) == {"A": [{"B": 1}, None, None]} + + +def test_explode_basic() -> None: + s = pl.Series + + assert_series_equal(s([[1, 2, 3]]).explode(), pl.Series([1, 2, 3])) + assert_series_equal(s([[1, 2, 3], None]).explode(), pl.Series([1, 2, 3, None])) + assert_series_equal(s([[1, 2, 3], []]).explode(), pl.Series([1, 2, 3, None])) + masked = ( + s([[1, 2, 3], [1, 2], [1, 2]]) + .to_frame() + .select(pl.when(pl.Series([True, False, True])).then(pl.col(""))) + .to_series() + ) + assert_series_equal(masked.explode(), pl.Series([1, 2, 3, None, 1, 2])) + masked = ( + s([[1, 2, 3], [], [1, 2]]) + .to_frame() + .select(pl.when(pl.Series([True, False, True])).then(pl.col(""))) + .to_series() + ) + assert_series_equal(masked.explode(), pl.Series([1, 2, 3, None, 1, 2])) + + assert_series_equal( + s([[1, 2, 3]]).explode(empty_as_null=False, keep_nulls=False), + pl.Series([1, 2, 3]), + ) + + assert_series_equal(s([[1, 2, 3], None]).explode(), pl.Series([1, 2, 3, None])) + assert_series_equal( + s([[1, 2, 3], None]).explode(keep_nulls=False), pl.Series([1, 2, 3]) + ) + assert_series_equal( + s([[1, 2, 3], [None]]).explode(keep_nulls=False), pl.Series([1, 2, 3, None]) + ) + + assert_series_equal(s([[1, 2, 3], []]).explode(), pl.Series([1, 2, 3, None])) + assert_series_equal( + s([[1, 2, 3], []]).explode(empty_as_null=False), pl.Series([1, 2, 3]) + ) + assert_series_equal( + s([[1, 2, 3], [None]]).explode(empty_as_null=False), pl.Series([1, 2, 3, None]) + ) + + +@given(s=series(min_size=1)) +@pytest.mark.parametrize("empty_as_null", [False, True]) +@pytest.mark.parametrize("keep_nulls", [False, True]) +def test_explode_parametric( + s: pl.Series, empty_as_null: bool, keep_nulls: bool +) -> None: + a = {"empty_as_null": empty_as_null, "keep_nulls": keep_nulls} + si = s.implode() + + empty_list_item = s.clear(1) if empty_as_null else s.clear() + null_list_item = s.clear(1) if keep_nulls else s.clear() + + assert_series_equal(si.explode(**a), s) + assert_series_equal(s.clear().implode().explode(**a), empty_list_item) + assert_series_equal(si.clear(1).explode(**a), null_list_item) + + assert_series_equal( + pl.concat([si, s.clear().implode(), si]).explode(**a), + pl.concat([s, empty_list_item, s]), + ) + assert_series_equal( + pl.concat([si, si.clear(1), si]).explode(**a), pl.concat([s, null_list_item, s]) + ) + + for mask in [ + (False, False, False), + (True, False, True), + (False, False, True), + (True, False, False), + (False, True, False), + ]: + masked = ( + pl.concat([si, si, si]) + .to_frame() + .select(pl.when(pl.Series(mask)).then(pl.col(s.name)).alias(s.name)) + .to_series() + ) + assert_series_equal( + masked.explode(**a), pl.concat([s if m else null_list_item for m in mask]) + ) + + for size in [2, 3, 7, 15]: + assert_series_equal(pl.concat([si] * size).explode(**a), pl.concat([s] * size)) + + assert_series_equal( + pl.concat([s.clear().implode()] + [si] * size).explode(**a), + pl.concat([empty_list_item] + [s] * size), + ) + assert_series_equal( + pl.concat([si] * size + [s.clear().implode()]).explode(**a), + pl.concat([s] * size + [empty_list_item]), + ) + + assert_series_equal( + pl.concat([si.clear(1)] + [si] * size).explode(**a), + pl.concat([null_list_item] + [s] * size), + ) + assert_series_equal( + pl.concat([si] * size + [si.clear(1)]).explode(**a), + pl.concat([s] * size + [null_list_item]), + ) + + +def test_explode_array_parameters() -> None: + s = pl.Series("a", [[1, 2, 3], [4, 5, 6], [7, 8, 9]], pl.Array(pl.Int64, 3)) + assert_series_equal(s.explode(), pl.Series("a", list(range(1, 10)), pl.Int64)) + + s = pl.Series("a", [[1, 2, 3], [4, 5, 6], None], pl.Array(pl.Int64, 3)) + assert_series_equal( + s.explode(), pl.Series("a", list(range(1, 7)) + [None], pl.Int64) + ) + assert_series_equal( + s.explode(keep_nulls=False), pl.Series("a", list(range(1, 7)), pl.Int64) + ) + + s = pl.Series("a", [[], [], None], pl.Array(pl.Int64, 0)) + assert_series_equal(s.explode(), pl.Series("a", [None] * 3, pl.Int64)) + assert_series_equal( + s.explode(keep_nulls=False), pl.Series("a", [None] * 2, pl.Int64) + ) + assert_series_equal( + s.explode(empty_as_null=False), pl.Series("a", [None], pl.Int64) + ) + assert_series_equal( + s.explode(empty_as_null=False, keep_nulls=False), pl.Series("a", [], pl.Int64) + )