Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions crates/polars-ops/src/chunked_array/array/to_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use rayon::prelude::*;

use super::*;

pub type ArrToStructNameGenerator = Arc<dyn Fn(usize) -> PlSmallStr + Send + Sync>;
pub type ArrToStructNameGenerator = Arc<dyn Fn(usize) -> PolarsResult<PlSmallStr> + Send + Sync>;

pub fn arr_default_struct_name_gen(idx: usize) -> PlSmallStr {
format_pl_smallstr!("field_{idx}")
Expand All @@ -21,7 +21,7 @@ pub trait ToStruct: AsArray {

let name_generator = name_generator
.as_deref()
.unwrap_or(&arr_default_struct_name_gen);
.unwrap_or(&|i| Ok(arr_default_struct_name_gen(i)));

let fields = POOL.install(|| {
(0..n_fields)
Expand All @@ -31,9 +31,9 @@ pub trait ToStruct: AsArray {
&Int64Chunked::from_slice(PlSmallStr::EMPTY, &[i as i64]),
true,
)
.map(|mut s| {
s.rename(name_generator(i).clone());
s
.and_then(|mut s| {
s.rename(name_generator(i)?.clone());
Ok(s)
})
})
.collect::<PolarsResult<Vec<_>>>()
Expand Down
92 changes: 14 additions & 78 deletions crates/polars-ops/src/chunked_array/list/to_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@ use rayon::prelude::*;
use super::*;

#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum ListToStructArgs {
FixedWidth(Arc<[PlSmallStr]>),
InferWidth {
infer_field_strategy: ListToStructWidthStrategy,
// Can serialize only when None (null), so we override to `()` which serializes to null.
#[cfg_attr(feature = "dsl-schema", schemars(with = "()"))]
get_index_name: Option<NameGenerator>,
/// If this is None, it means unbounded.
max_fields: Option<usize>,
Expand All @@ -29,42 +26,6 @@ pub enum ListToStructWidthStrategy {
}

impl ListToStructArgs {
pub fn get_output_dtype(&self, input_dtype: &DataType) -> PolarsResult<DataType> {
let DataType::List(inner_dtype) = input_dtype else {
polars_bail!(
InvalidOperation:
"attempted list to_struct on non-list dtype: {}",
input_dtype
);
};
let inner_dtype = inner_dtype.as_ref();

match self {
Self::FixedWidth(names) => Ok(DataType::Struct(
names
.iter()
.map(|x| Field::new(x.clone(), inner_dtype.clone()))
.collect::<Vec<_>>(),
)),
Self::InferWidth {
get_index_name,
max_fields: Some(max_fields),
..
} => {
let get_index_name_func = get_index_name.as_ref().map_or(
&_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr,
|x| x.0.as_ref(),
);
Ok(DataType::Struct(
(0..*max_fields)
.map(|i| Field::new(get_index_name_func(i), inner_dtype.clone()))
.collect::<Vec<_>>(),
))
},
Self::InferWidth { .. } => Ok(DataType::Unknown(UnknownKind::Any)),
}
}

fn det_n_fields(&self, ca: &ListChunked) -> usize {
match self {
Self::FixedWidth(v) => v.len(),
Expand Down Expand Up @@ -117,7 +78,7 @@ impl ListToStructArgs {
}
}

fn set_output_names(&self, columns: &mut [Series]) {
fn set_output_names(&self, columns: &mut [Series]) -> PolarsResult<()> {
match self {
Self::FixedWidth(v) => {
assert_eq!(columns.len(), v.len());
Expand All @@ -127,24 +88,29 @@ impl ListToStructArgs {
}
},
Self::InferWidth { get_index_name, .. } => {
let get_index_name_func = get_index_name.as_ref().map_or(
&_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr,
|x| x.0.as_ref(),
);
let get_index_name_func = get_index_name
.as_ref()
.map_or(&(|i| Ok(_default_struct_name_gen(i))) as _, |x| {
x.0.as_ref()
});

for (i, c) in columns.iter_mut().enumerate() {
c.rename(get_index_name_func(i));
c.rename(get_index_name_func(i)?);
}
},
}

Ok(())
}
}

#[derive(Clone)]
pub struct NameGenerator(pub Arc<dyn Fn(usize) -> PlSmallStr + Send + Sync>);
pub struct NameGenerator(pub Arc<dyn Fn(usize) -> PolarsResult<PlSmallStr> + Send + Sync>);

impl NameGenerator {
pub fn from_func(func: impl Fn(usize) -> PlSmallStr + Send + Sync + 'static) -> Self {
pub fn from_func(
func: impl Fn(usize) -> PolarsResult<PlSmallStr> + Send + Sync + 'static,
) -> Self {
Self(Arc::new(func))
}
}
Expand Down Expand Up @@ -189,40 +155,10 @@ pub trait ToStruct: AsList {
.collect::<PolarsResult<Vec<_>>>()
})?;

args.set_output_names(&mut fields);
args.set_output_names(&mut fields)?;

StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter())
}
}

// Empty impl body: `ListChunked` relies entirely on the methods the `ToStruct`
// trait already provides.
impl ToStruct for ListChunked {}

#[cfg(feature = "serde")]
mod _serde_impl {
    use super::*;

    /// A `NameGenerator` wraps an arbitrary closure, so it has no serialized form:
    /// serialization always fails with a message pointing at the alternative.
    impl serde::Serialize for NameGenerator {
        fn serialize<S: serde::Serializer>(&self, _serializer: S) -> Result<S::Ok, S::Error> {
            let message = "cannot serialize name generator function for to_struct, \
                consider passing a list of field names instead.";
            Err(<S::Error as serde::ser::Error>::custom(message))
        }
    }

    /// Deserialization is rejected unconditionally; there is no data a
    /// `NameGenerator` could be rebuilt from.
    impl<'de> serde::Deserialize<'de> for NameGenerator {
        fn deserialize<D: serde::Deserializer<'de>>(_deserializer: D) -> Result<Self, D::Error> {
            let message = "invalid data: attempted to deserialize list::to_struct::NameGenerator";
            Err(<D::Error as serde::de::Error>::custom(message))
        }
    }
}
2 changes: 1 addition & 1 deletion crates/polars-plan/dsl-schema.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
83d363596632a622e78f7283b8e8d66cd23333f9e8382a220d54d85ad51cdc11
9e05f86bc8776540ba40685e8e53e90d4b83e86ea25bf45ed4372be1e17b9d88
28 changes: 2 additions & 26 deletions crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use polars_core::prelude::*;
#[cfg(feature = "array_to_struct")]
use polars_ops::chunked_array::array::{
ArrToStructNameGenerator, ToStruct, arr_default_struct_name_gen,
};

use crate::dsl::function_expr::ArrayFunction;
use crate::prelude::*;
Expand Down Expand Up @@ -147,28 +143,8 @@ impl ArrayNameSpace {
}

#[cfg(feature = "array_to_struct")]
pub fn to_struct(self, name_generator: Option<ArrToStructNameGenerator>) -> PolarsResult<Expr> {
Ok(self.0.map_with_fmt_str(
move |s| {
s.array()?
.to_struct(name_generator.clone())
.map(|s| Some(s.into_column()))
},
GetOutput::map_dtype(move |dt: &DataType| {
let DataType::Array(inner, width) = dt else {
polars_bail!(InvalidOperation: "expected Array type, got: {}", dt)
};

let fields = (0..*width)
.map(|i| {
let name = arr_default_struct_name_gen(i);
Field::new(name, inner.as_ref().clone())
})
.collect();
Ok(DataType::Struct(fields))
}),
"arr.to_struct",
))
pub fn to_struct(self, name_generator: Option<DslNameGenerator>) -> Expr {
self.0.map_unary(ArrayFunction::ToStruct(name_generator))
}

/// Slice every subarray.
Expand Down
14 changes: 13 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::fmt;

use polars_core::prelude::SortOptions;

#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
use super::FunctionExpr;

#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum ArrayFunction {
Expand Down Expand Up @@ -38,6 +40,8 @@ pub enum ArrayFunction {
skip_empty: bool,
},
Concat,
#[cfg(feature = "array_to_struct")]
ToStruct(Option<super::DslNameGenerator>),
}

impl fmt::Display for ArrayFunction {
Expand Down Expand Up @@ -72,7 +76,15 @@ impl fmt::Display for ArrayFunction {
CountMatches => "count_matches",
Shift => "shift",
Explode { .. } => "explode",
#[cfg(feature = "array_to_struct")]
ToStruct(_) => "to_struct",
};
write!(f, "arr.{name}")
}
}

impl From<ArrayFunction> for FunctionExpr {
fn from(value: ArrayFunction) -> Self {
Self::ArrayExpr(value)
}
}
22 changes: 21 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
use super::*;

// DSL-level arguments for `list.to_struct`; this is the payload carried by
// `ListFunction::ToStruct`.
//
// NOTE(review): plain `//` comments are used on purpose. With the `dsl-schema`
// feature, schemars may copy `///` doc text into the generated JSON schema, which
// would change the schema hash — confirm before converting these to doc comments.
#[cfg(feature = "list_to_struct")]
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum ListToStruct {
// Explicit list of names; presumably one struct field per name — verify against
// the lowering code.
FixedWidth(Arc<[PlSmallStr]>),
// The number of fields is not given up front and is derived using
// `infer_field_strategy`.
InferWidth {
infer_field_strategy: ListToStructWidthStrategy,
// Optional callback used to name fields by index (a `DslNameGenerator`).
// NOTE(review): what happens on `None` is decided outside this declaration.
get_index_name: Option<DslNameGenerator>,
/// If this is None, it means unbounded.
max_fields: Option<usize>,
},
}

#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
Expand Down Expand Up @@ -56,7 +70,7 @@ pub enum ListFunction {
#[cfg(feature = "dtype-array")]
ToArray(usize),
#[cfg(feature = "list_to_struct")]
ToStruct(ListToStructArgs),
ToStruct(ListToStruct),
}

impl Display for ListFunction {
Expand Down Expand Up @@ -123,3 +137,9 @@ impl Display for ListFunction {
write!(f, "list.{name}")
}
}

impl From<ListFunction> for FunctionExpr {
fn from(value: ListFunction) -> Self {
Self::ListExpr(value)
}
}
5 changes: 5 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ pub use array::ArrayFunction;
#[cfg(feature = "cov")]
pub use correlation::CorrelationMethod;
pub use list::ListFunction;
#[cfg(feature = "list_to_struct")]
pub use list::ListToStruct;
pub use polars_core::datatypes::ReshapeDimension;
use polars_core::prelude::*;
#[cfg(feature = "random")]
Expand Down Expand Up @@ -785,3 +787,6 @@ impl Display for FunctionExpr {
write!(f, "{s}")
}
}

/// Name-generator callback type shared by the array and list `to_struct` DSL
/// functions; only compiled when at least one of those features is enabled.
///
/// NOTE(review): presumably maps a field index (`usize`) to a field name
/// (`String`) — confirm against the `PlanCallback` definition.
#[cfg(any(feature = "array_to_struct", feature = "list_to_struct"))]
pub type DslNameGenerator = PlanCallback<usize, String>;
5 changes: 2 additions & 3 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,8 @@ impl ListNameSpace {
/// an `upper_bound` of struct fields that will be set.
/// If this is set incorrectly, downstream operations may fail. For instance, an `all().sum()` expression
/// will look in the current schema to determine which columns to select.
pub fn to_struct(self, args: ListToStructArgs) -> Expr {
self.0
.map_unary(FunctionExpr::ListExpr(ListFunction::ToStruct(args)))
pub fn to_struct(self, args: ListToStruct) -> Expr {
self.0.map_unary(ListFunction::ToStruct(args))
}

#[cfg(feature = "is_in")]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use super::*;
// - changing a name, type, or meaning of a field or an enum variant
// - changing a default value of a field or a default enum variant
// - restricting the range of allowed values a field can have
pub static DSL_VERSION: (u16, u16) = (20, 1);
pub static DSL_VERSION: (u16, u16) = (20, 2);
static DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION";

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down
Loading
Loading