diff --git a/crates/polars-ops/src/chunked_array/array/to_struct.rs b/crates/polars-ops/src/chunked_array/array/to_struct.rs index d316e82888c0..07e8812f44f4 100644 --- a/crates/polars-ops/src/chunked_array/array/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/array/to_struct.rs @@ -5,7 +5,7 @@ use rayon::prelude::*; use super::*; -pub type ArrToStructNameGenerator = Arc PlSmallStr + Send + Sync>; +pub type ArrToStructNameGenerator = Arc PolarsResult + Send + Sync>; pub fn arr_default_struct_name_gen(idx: usize) -> PlSmallStr { format_pl_smallstr!("field_{idx}") @@ -21,7 +21,7 @@ pub trait ToStruct: AsArray { let name_generator = name_generator .as_deref() - .unwrap_or(&arr_default_struct_name_gen); + .unwrap_or(&|i| Ok(arr_default_struct_name_gen(i))); let fields = POOL.install(|| { (0..n_fields) @@ -31,9 +31,9 @@ pub trait ToStruct: AsArray { &Int64Chunked::from_slice(PlSmallStr::EMPTY, &[i as i64]), true, ) - .map(|mut s| { - s.rename(name_generator(i).clone()); - s + .and_then(|mut s| { + s.rename(name_generator(i)?.clone()); + Ok(s) }) }) .collect::>>() diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs index aa4f26b379cc..bcedfce5fb22 100644 --- a/crates/polars-ops/src/chunked_array/list/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -6,14 +6,11 @@ use rayon::prelude::*; use super::*; #[derive(Clone, Eq, PartialEq, Hash, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))] pub enum ListToStructArgs { FixedWidth(Arc<[PlSmallStr]>), InferWidth { infer_field_strategy: ListToStructWidthStrategy, // Can serialize only when None (null), so we override to `()` which serializes to null. - #[cfg_attr(feature = "dsl-schema", schemars(with = "()"))] get_index_name: Option, /// If this is None, it means unbounded. max_fields: Option, @@ -29,42 +26,6 @@ pub enum ListToStructWidthStrategy { } impl ListToStructArgs { - pub fn get_output_dtype(&self, input_dtype: &DataType) -> PolarsResult { - let DataType::List(inner_dtype) = input_dtype else { - polars_bail!( - InvalidOperation: - "attempted list to_struct on non-list dtype: {}", - input_dtype - ); - }; - let inner_dtype = inner_dtype.as_ref(); - - match self { - Self::FixedWidth(names) => Ok(DataType::Struct( - names - .iter() - .map(|x| Field::new(x.clone(), inner_dtype.clone())) - .collect::>(), - )), - Self::InferWidth { - get_index_name, - max_fields: Some(max_fields), - .. - } => { - let get_index_name_func = get_index_name.as_ref().map_or( - &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, - |x| x.0.as_ref(), - ); - Ok(DataType::Struct( - (0..*max_fields) - .map(|i| Field::new(get_index_name_func(i), inner_dtype.clone())) - .collect::>(), - )) - }, - Self::InferWidth { .. } => Ok(DataType::Unknown(UnknownKind::Any)), - } - } - fn det_n_fields(&self, ca: &ListChunked) -> usize { match self { Self::FixedWidth(v) => v.len(), @@ -117,7 +78,7 @@ impl ListToStructArgs { } } - fn set_output_names(&self, columns: &mut [Series]) { + fn set_output_names(&self, columns: &mut [Series]) -> PolarsResult<()> { match self { Self::FixedWidth(v) => { assert_eq!(columns.len(), v.len()); @@ -127,24 +88,29 @@ impl ListToStructArgs { } }, Self::InferWidth { get_index_name, .. } => { - let get_index_name_func = get_index_name.as_ref().map_or( - &_default_struct_name_gen as &dyn Fn(usize) -> PlSmallStr, - |x| x.0.as_ref(), - ); + let get_index_name_func = get_index_name + .as_ref() + .map_or(&(|i| Ok(_default_struct_name_gen(i))) as _, |x| { + x.0.as_ref() + }); for (i, c) in columns.iter_mut().enumerate() { - c.rename(get_index_name_func(i)); + c.rename(get_index_name_func(i)?); } }, } + + Ok(()) } } #[derive(Clone)] -pub struct NameGenerator(pub Arc PlSmallStr + Send + Sync>); +pub struct NameGenerator(pub Arc PolarsResult + Send + Sync>); impl NameGenerator { - pub fn from_func(func: impl Fn(usize) -> PlSmallStr + Send + Sync + 'static) -> Self { + pub fn from_func( + func: impl Fn(usize) -> PolarsResult + Send + Sync + 'static, + ) -> Self { Self(Arc::new(func)) } } @@ -189,40 +155,10 @@ pub trait ToStruct: AsList { .collect::>>() })?; - args.set_output_names(&mut fields); + args.set_output_names(&mut fields)?; StructChunked::from_series(ca.name().clone(), ca.len(), fields.iter()) } } impl ToStruct for ListChunked {} - -#[cfg(feature = "serde")] -mod _serde_impl { - use super::*; - - impl serde::Serialize for NameGenerator { - fn serialize(&self, _serializer: S) -> Result - where - S: serde::Serializer, - { - use serde::ser::Error; - Err(S::Error::custom( - "cannot serialize name generator function for to_struct, \ - consider passing a list of field names instead.", - )) - } - } - - impl<'de> serde::Deserialize<'de> for NameGenerator { - fn deserialize(_deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - use serde::de::Error; - Err(D::Error::custom( - "invalid data: attempted to deserialize list::to_struct::NameGenerator", - )) - } - } -} diff --git a/crates/polars-plan/dsl-schema.sha256 b/crates/polars-plan/dsl-schema.sha256 index 6270c8a02f7e..aabfbb376722 100644 --- a/crates/polars-plan/dsl-schema.sha256 +++ b/crates/polars-plan/dsl-schema.sha256 @@ -1 +1 @@ -83d363596632a622e78f7283b8e8d66cd23333f9e8382a220d54d85ad51cdc11 \ No newline at end of file +9e05f86bc8776540ba40685e8e53e90d4b83e86ea25bf45ed4372be1e17b9d88 \ No newline at end of file diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index 0123e5b136c2..cc7ad626c1bc 100644 --- a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -1,8 +1,4 @@ use polars_core::prelude::*; -#[cfg(feature = "array_to_struct")] -use polars_ops::chunked_array::array::{ - ArrToStructNameGenerator, ToStruct, arr_default_struct_name_gen, -}; use crate::dsl::function_expr::ArrayFunction; use crate::prelude::*; @@ -147,28 +143,8 @@ impl ArrayNameSpace { } #[cfg(feature = "array_to_struct")] - pub fn to_struct(self, name_generator: Option) -> PolarsResult { - Ok(self.0.map_with_fmt_str( - move |s| { - s.array()? - .to_struct(name_generator.clone()) - .map(|s| Some(s.into_column())) - }, - GetOutput::map_dtype(move |dt: &DataType| { - let DataType::Array(inner, width) = dt else { - polars_bail!(InvalidOperation: "expected Array type, got: {}", dt) - }; - - let fields = (0..*width) - .map(|i| { - let name = arr_default_struct_name_gen(i); - Field::new(name, inner.as_ref().clone()) - }) - .collect(); - Ok(DataType::Struct(fields)) - }), - "arr.to_struct", - )) + pub fn to_struct(self, name_generator: Option) -> Expr { + self.0.map_unary(ArrayFunction::ToStruct(name_generator)) } /// Slice every subarray. diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs index 891849ee205a..8672f11072fd 100644 --- a/crates/polars-plan/src/dsl/function_expr/array.rs +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -2,7 +2,9 @@ use std::fmt; use polars_core::prelude::SortOptions; -#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +use super::FunctionExpr; + +#[derive(Clone, Eq, PartialEq, Hash, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))] pub enum ArrayFunction { @@ -38,6 +40,8 @@ pub enum ArrayFunction { skip_empty: bool, }, Concat, + #[cfg(feature = "array_to_struct")] + ToStruct(Option), } impl fmt::Display for ArrayFunction { @@ -72,7 +76,15 @@ impl fmt::Display for ArrayFunction { CountMatches => "count_matches", Shift => "shift", Explode { .. } => "explode", + #[cfg(feature = "array_to_struct")] + ToStruct(_) => "to_struct", }; write!(f, "arr.{name}") } } + +impl From for FunctionExpr { + fn from(value: ArrayFunction) -> Self { + Self::ArrayExpr(value) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index a8149f30c254..4e7b14047de9 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -1,5 +1,19 @@ use super::*; +#[cfg(feature = "list_to_struct")] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))] +pub enum ListToStruct { + FixedWidth(Arc<[PlSmallStr]>), + InferWidth { + infer_field_strategy: ListToStructWidthStrategy, + get_index_name: Option, + /// If this is None, it means unbounded. + max_fields: Option, + }, +} + #[derive(Clone, Eq, PartialEq, Hash, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))] @@ -56,7 +70,7 @@ pub enum ListFunction { #[cfg(feature = "dtype-array")] ToArray(usize), #[cfg(feature = "list_to_struct")] - ToStruct(ListToStructArgs), + ToStruct(ListToStruct), } impl Display for ListFunction { @@ -123,3 +137,9 @@ impl Display for ListFunction { write!(f, "list.{name}") } } + +impl From for FunctionExpr { + fn from(value: ListFunction) -> Self { + Self::ListExpr(value) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 9d1246a489a7..1ce7506f3f5b 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -37,6 +37,8 @@ pub use array::ArrayFunction; #[cfg(feature = "cov")] pub use correlation::CorrelationMethod; pub use list::ListFunction; +#[cfg(feature = "list_to_struct")] +pub use list::ListToStruct; pub use polars_core::datatypes::ReshapeDimension; use polars_core::prelude::*; #[cfg(feature = "random")] @@ -785,3 +787,6 @@ impl Display for FunctionExpr { write!(f, "{s}") } } + +#[cfg(any(feature = "array_to_struct", feature = "list_to_struct"))] +pub type DslNameGenerator = PlanCallback; diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 74a00c49c652..90c3e552be4b 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -243,9 +243,8 @@ impl ListNameSpace { /// an `upper_bound` of struct fields that will be set. /// If this is incorrectly downstream operation may fail. For instance an `all().sum()` expression /// will look in the current schema to determine which columns to select. - pub fn to_struct(self, args: ListToStructArgs) -> Expr { - self.0 - .map_unary(FunctionExpr::ListExpr(ListFunction::ToStruct(args))) + pub fn to_struct(self, args: ListToStruct) -> Expr { + self.0.map_unary(ListFunction::ToStruct(args)) } #[cfg(feature = "is_in")] diff --git a/crates/polars-plan/src/dsl/plan.rs b/crates/polars-plan/src/dsl/plan.rs index 2f44346b79ca..d870e2b06808 100644 --- a/crates/polars-plan/src/dsl/plan.rs +++ b/crates/polars-plan/src/dsl/plan.rs @@ -48,7 +48,7 @@ use super::*; // - changing a name, type, or meaning of a field or an enum variant // - changing a default value of a field or a default enum variant // - restricting the range of allowed values a field can have -pub static DSL_VERSION: (u16, u16) = (20, 1); +pub static DSL_VERSION: (u16, u16) = (20, 2); static DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION"; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/crates/polars-plan/src/plans/aexpr/function_expr/array.rs b/crates/polars-plan/src/plans/aexpr/function_expr/array.rs index 68f94c13be88..1399a34b7a0d 100644 --- a/crates/polars-plan/src/plans/aexpr/function_expr/array.rs +++ b/crates/polars-plan/src/plans/aexpr/function_expr/array.rs @@ -4,7 +4,7 @@ use polars_ops::chunked_array::array::*; use super::*; use crate::{map, map_as_slice}; -#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] #[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))] pub enum IRArrayFunction { Length, @@ -39,6 +39,8 @@ pub enum IRArrayFunction { }, Concat, Slice(i64, i64), + #[cfg(feature = "array_to_struct")] + ToStruct(Option), } impl IRArrayFunction { @@ -79,6 +81,23 @@ impl IRArrayFunction { Slice(offset, length) => { mapper.try_map_dtype(map_to_array_fixed_length(offset, length)) }, + #[cfg(feature = "array_to_struct")] + ToStruct(name_generator) => mapper.try_map_dtype(|dtype| { + let DataType::Array(inner, width) = dtype else { + polars_bail!(InvalidOperation: "expected Array type, got: {dtype}") + }; + + (0..*width) + .map(|i| { + let name = match name_generator { + None => arr_default_struct_name_gen(i), + Some(ng) => PlSmallStr::from_string(ng.call(i)?), + }; + Ok(Field::new(name, inner.as_ref().clone())) + }) + .collect::>>() + .map(DataType::Struct) + }), } } @@ -112,6 +131,8 @@ impl IRArrayFunction { | A::Shift | A::Slice(_, _) => FunctionOptions::elementwise(), A::Explode { .. } => FunctionOptions::row_separable(), + #[cfg(feature = "array_to_struct")] + A::ToStruct(_) => FunctionOptions::elementwise(), } } } @@ -177,6 +198,8 @@ impl Display for IRArrayFunction { Shift => "shift", Slice(_, _) => "slice", Explode { .. } => "explode", + #[cfg(feature = "array_to_struct")] + ToStruct(_) => "to_struct", }; write!(f, "arr.{name}") } @@ -214,6 +237,8 @@ impl From for SpecialEq> { Shift => map_as_slice!(shift), Explode { skip_empty } => map_as_slice!(explode, skip_empty), Slice(offset, length) => map!(slice, offset, length), + #[cfg(feature = "array_to_struct")] + ToStruct(ng) => map!(arr_to_struct, ng.clone()), } } } @@ -412,3 +437,12 @@ fn concat_arr_output_dtype( out_width, )) } + +#[cfg(feature = "array_to_struct")] +fn arr_to_struct(s: &Column, name_generator: Option) -> PolarsResult { + let name_generator = + name_generator.map(|f| Arc::new(move |i| f.call(i).map(PlSmallStr::from)) as Arc<_>); + s.array()? + .to_struct(name_generator) + .map(IntoColumn::into_column) +} diff --git a/crates/polars-plan/src/plans/aexpr/function_expr/list.rs b/crates/polars-plan/src/plans/aexpr/function_expr/list.rs index 1389bce0621b..18a86e10bee6 100644 --- a/crates/polars-plan/src/plans/aexpr/function_expr/list.rs +++ b/crates/polars-plan/src/plans/aexpr/function_expr/list.rs @@ -60,7 +60,7 @@ pub enum IRListFunction { #[cfg(feature = "dtype-array")] ToArray(usize), #[cfg(feature = "list_to_struct")] - ToStruct(ListToStructArgs), + ToStruct(ListToStruct), } impl IRListFunction { @@ -128,7 +128,39 @@ impl IRListFunction { ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)), NUnique => mapper.with_dtype(IDX_DTYPE), #[cfg(feature = "list_to_struct")] - ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)), + ToStruct(args) => mapper.try_map_dtype(|dtype| { + let DataType::List(inner_dtype) = dtype else { + polars_bail!( + InvalidOperation: + "attempted list to_struct on non-list dtype: {dtype}", + ); + }; + let inner_dtype = inner_dtype.as_ref(); + + match args { + ListToStruct::FixedWidth(names) => Ok(DataType::Struct( + names + .iter() + .map(|x| Field::new(x.clone(), inner_dtype.clone())) + .collect::>(), + )), + ListToStruct::InferWidth { + get_index_name, + max_fields: Some(max_fields), + .. + } => (0..*max_fields) + .map(|i| { + let name = match get_index_name { + None => _default_struct_name_gen(i), + Some(ng) => PlSmallStr::from_string(ng.call(i)?), + }; + Ok(Field::new(name, inner_dtype.as_ref().clone())) + }) + .collect::>>() + .map(DataType::Struct), + ListToStruct::InferWidth { .. } => Ok(DataType::Unknown(UnknownKind::Any)), + } + }), } } @@ -180,9 +212,9 @@ impl IRListFunction { #[cfg(feature = "dtype-array")] L::ToArray(_) => FunctionOptions::elementwise(), #[cfg(feature = "list_to_struct")] - L::ToStruct(ListToStructArgs::FixedWidth(_)) => FunctionOptions::elementwise(), + L::ToStruct(ListToStruct::FixedWidth(_)) => FunctionOptions::elementwise(), #[cfg(feature = "list_to_struct")] - L::ToStruct(ListToStructArgs::InferWidth { .. }) => FunctionOptions::groupwise(), + L::ToStruct(ListToStruct::InferWidth { .. }) => FunctionOptions::length_preserving(), } } } @@ -779,8 +811,27 @@ pub(super) fn to_array(s: &Column, width: usize) -> PolarsResult { } #[cfg(feature = "list_to_struct")] -pub(super) fn to_struct(s: &Column, args: &ListToStructArgs) -> PolarsResult { - Ok(s.list()?.to_struct(args)?.into_series().into()) +pub(super) fn to_struct(s: &Column, args: &ListToStruct) -> PolarsResult { + let args = args.clone(); + let args = match args { + ListToStruct::FixedWidth(strs) => ListToStructArgs::FixedWidth(strs), + ListToStruct::InferWidth { + infer_field_strategy, + get_index_name, + max_fields, + } => { + let get_index_name = get_index_name.map(|f| { + NameGenerator(Arc::new(move |i| f.call(i).map(PlSmallStr::from)) as Arc<_>) + }); + + ListToStructArgs::InferWidth { + infer_field_strategy, + get_index_name, + max_fields, + } + }, + }; + Ok(s.list()?.to_struct(&args)?.into_column()) } pub(super) fn n_unique(s: &Column) -> PolarsResult { diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs index 200edfdd7076..4019fa13964a 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs @@ -94,6 +94,8 @@ pub(super) fn convert_functions( A::Explode { skip_empty } => IA::Explode { skip_empty }, A::Concat => IA::Concat, A::Slice(offset, length) => IA::Slice(offset, length), + #[cfg(feature = "array_to_struct")] + A::ToStruct(ng) => IA::ToStruct(ng), }) }, F::BinaryExpr(binary_function) => { diff --git a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs index ba92b4f588ce..766235f6995c 100644 --- a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs +++ b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs @@ -305,6 +305,8 @@ pub fn ir_function_to_dsl(input: Vec, function: IRFunctionExpr) -> Expr { IA::Shift => A::Shift, IA::Slice(offset, length) => A::Slice(offset, length), IA::Explode { skip_empty } => A::Explode { skip_empty }, + #[cfg(feature = "array_to_struct")] + IA::ToStruct(ng) => A::ToStruct(ng), }) }, IF::BinaryExpr(f) => { diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index ad350b597dd4..533431b0978b 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -1,7 +1,9 @@ -use std::fmt::Formatter; +use std::fmt::{self, Formatter}; use std::iter::FlatMap; use polars_core::prelude::*; +#[cfg(feature = "python")] +use polars_utils::python_function::PythonFunction; use self::visitor::{AexprNode, RewritingVisitor, TreeWalker}; use crate::constants::get_len_name; @@ -335,3 +337,116 @@ pub fn rename_columns( .unwrap() .node() } + +#[derive(Eq, PartialEq)] +pub enum PlanCallback { + #[cfg(feature = "python")] + Python(SpecialEq>), + Rust(SpecialEq PolarsResult + Send + Sync>>), +} + +impl fmt::Debug for PlanCallback { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("PlanCallback::")?; + std::mem::discriminant(self).fmt(f) + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for PlanCallback { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + + #[cfg(feature = "python")] + if let Self::Python(v) = self { + return v.serialize(_serializer); + } + + Err(S::Error::custom(format!( + "cannot serialize 'opaque' function in {self:?}" + ))) + } +} + +#[cfg(feature = "serde")] +impl<'de, Args, Out> serde::Deserialize<'de> for PlanCallback { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[cfg(feature = "python")] + { + Ok(Self::Python(SpecialEq::new(Arc::new( + polars_utils::python_function::PythonFunction::deserialize(_deserializer)?, + )))) + } + #[cfg(not(feature = "python"))] + { + use serde::de::Error; + Err(D::Error::custom("cannot deserialize PlanCallback")) + } + } +} + +#[cfg(feature = "dsl-schema")] +impl schemars::JsonSchema for PlanCallback { + fn schema_name() -> String { + "PlanCallback".to_owned() + } + + fn schema_id() -> std::borrow::Cow<'static, str> { + std::borrow::Cow::Borrowed(concat!(module_path!(), "::", "PlanCallback")) + } + + fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema { + Vec::::json_schema(generator) + } +} + +impl std::hash::Hash for PlanCallback { + fn hash(&self, _state: &mut H) { + // no-op. + } +} + +impl Clone for PlanCallback { + fn clone(&self) -> Self { + match self { + #[cfg(feature = "python")] + Self::Python(p) => Self::Python(p.clone()), + Self::Rust(f) => Self::Rust(f.clone()), + } + } +} + +#[cfg(feature = "python")] +impl pyo3::IntoPyObject<'py>, Out: for<'py> pyo3::FromPyObject<'py>> + PlanCallback +{ + pub fn call(&self, args: Args) -> PolarsResult { + match self { + #[cfg(feature = "python")] + Self::Python(pyfn) => pyo3::Python::with_gil(|py| { + let out = pyfn.call1(py, (args,))?.extract::(py)?; + Ok(out) + }), + Self::Rust(f) => f(args), + } + } + + pub fn new_python(pyfn: PythonFunction) -> Self { + Self::Python(SpecialEq::new(Arc::new(pyfn))) + } +} + +#[cfg(not(feature = "python"))] +impl PlanCallback { + pub fn call(&self, args: Args) -> PolarsResult { + match self { + Self::Rust(f) => f(args), + } + } +} diff --git a/crates/polars-python/src/expr/array.rs b/crates/polars-python/src/expr/array.rs index 7f11eed828a8..d3d371b1b2a8 100644 --- a/crates/polars-python/src/expr/array.rs +++ b/crates/polars-python/src/expr/array.rs @@ -1,8 +1,7 @@ use polars::prelude::*; -use polars_ops::prelude::array::ArrToStructNameGenerator; -use polars_utils::pl_str::PlSmallStr; +use polars_plan::utils::PlanCallback; +use polars_utils::python_function::PythonObject; use pyo3::prelude::*; -use pyo3::pybacked::PyBackedStr; use pyo3::pymethods; use crate::error::PyPolarsErr; @@ -117,24 +116,9 @@ impl PyExpr { } #[pyo3(signature = (name_gen))] - fn arr_to_struct(&self, name_gen: Option) -> PyResult { - let name_gen = name_gen.map(|lambda| { - Arc::new(move |idx: usize| { - Python::with_gil(|py| { - let out = lambda.call1(py, (idx,)).unwrap(); - let out: PlSmallStr = (&*out.extract::(py).unwrap()).into(); - out - }) - }) as ArrToStructNameGenerator - }); - - Ok(self - .inner - .clone() - .arr() - .to_struct(name_gen) - .map_err(PyPolarsErr::from)? - .into()) + fn arr_to_struct(&self, name_gen: Option) -> Self { + let name_gen = name_gen.map(|o| PlanCallback::new_python(PythonObject(o))); + self.inner.clone().arr().to_struct(name_gen).into() } fn arr_slice(&self, offset: PyExpr, length: Option, as_array: bool) -> PyResult { diff --git a/crates/polars-python/src/expr/list.rs b/crates/polars-python/src/expr/list.rs index 7886c0999472..a1f1df7f9f34 100644 --- a/crates/polars-python/src/expr/list.rs +++ b/crates/polars-python/src/expr/list.rs @@ -1,8 +1,8 @@ -use std::borrow::Cow; - use polars::prelude::*; use polars::series::ops::NullBehavior; +use polars_plan::utils::PlanCallback; use polars_utils::pl_str::PlSmallStr; +use polars_utils::python_function::PythonObject; use pyo3::prelude::*; use pyo3::types::PySequence; @@ -207,21 +207,13 @@ impl PyExpr { name_gen: Option, upper_bound: Option, ) -> PyResult { - let name_gen = name_gen.map(|lambda| { - NameGenerator::from_func(move |idx: usize| { - Python::with_gil(|py| { - let out = lambda.call1(py, (idx,)).unwrap(); - let out: PlSmallStr = out.extract::>(py).unwrap().as_ref().into(); - out - }) - }) - }); + let name_gen = name_gen.map(|lambda| PlanCallback::new_python(PythonObject(lambda))); Ok(self .inner .clone() .list() - .to_struct(ListToStructArgs::InferWidth { + .to_struct(ListToStruct::InferWidth { infer_field_strategy: width_strat.0, get_index_name: name_gen, max_fields: upper_bound, @@ -235,7 +227,7 @@ impl PyExpr { .inner .clone() .list() - .to_struct(ListToStructArgs::FixedWidth( + .to_struct(ListToStruct::FixedWidth( names .try_iter()? .map(|x| Ok(x?.extract::>()?.0)) diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index fbf355ac48c7..7a4186c1e272 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -123,6 +123,7 @@ def test_hconcat_projection_pushdown_length_maintained() -> None: @pytest.mark.may_fail_auto_streaming +@pytest.mark.may_fail_cloud def test_unnest_columns_available() -> None: df = pl.DataFrame( {