Skip to content

Commit a5dc30d

Browse files
authored
fix: Various schema corrections (#18474)
1 parent 37e9ccd commit a5dc30d

File tree

9 files changed

+301
-173
lines changed

9 files changed

+301
-173
lines changed

crates/polars-plan/src/plans/aexpr/schema.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ impl AExpr {
5656
*nested = 0;
5757
Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE))
5858
},
59-
Window { function, .. } => {
59+
Window {
60+
function, options, ..
61+
} => {
62+
if let WindowType::Over(mapping) = options {
63+
*nested += matches!(mapping, WindowMapping::Join) as u8;
64+
}
6065
let e = arena.get(*function);
6166
e.to_field_impl(schema, arena, nested)
6267
},

crates/polars-plan/src/plans/conversion/expr_to_ir.rs

+15-59
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use super::*;
2+
use crate::plans::conversion::functions::convert_functions;
23

34
pub fn to_expr_ir(expr: Expr, arena: &mut Arena<AExpr>) -> PolarsResult<ExprIR> {
45
let mut state = ConversionContext::new();
@@ -40,12 +41,12 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> PolarsResult<Node> {
4041
}
4142

4243
#[derive(Default)]
43-
struct ConversionContext {
44-
output_name: OutputName,
44+
pub(super) struct ConversionContext {
45+
pub(super) output_name: OutputName,
4546
/// Remove alias from the expressions and set as [`OutputName`].
46-
prune_alias: bool,
47+
pub(super) prune_alias: bool,
4748
/// If an `alias` is encountered prune and ignore it.
48-
ignore_alias: bool,
49+
pub(super) ignore_alias: bool,
4950
}
5051

5152
impl ConversionContext {
@@ -68,14 +69,17 @@ fn to_aexprs(
6869
.collect()
6970
}
7071

71-
fn set_function_output_name<F>(e: &[ExprIR], state: &mut ConversionContext, function_fmt: F)
72-
where
73-
F: FnOnce() -> Cow<'static, str>,
72+
pub(super) fn set_function_output_name<F>(
73+
e: &[ExprIR],
74+
state: &mut ConversionContext,
75+
function_fmt: F,
76+
) where
77+
F: FnOnce() -> PlSmallStr,
7478
{
7579
if state.output_name.is_none() {
7680
if e.is_empty() {
7781
let s = function_fmt();
78-
state.output_name = OutputName::LiteralLhs(PlSmallStr::from_str(s.as_ref()));
82+
state.output_name = OutputName::LiteralLhs(s);
7983
} else {
8084
state.output_name = e[0].output_name_inner().clone();
8185
}
@@ -117,7 +121,7 @@ fn to_aexpr_impl_materialized_lit(
117121

118122
/// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation.
119123
#[recursive]
120-
fn to_aexpr_impl(
124+
pub(super) fn to_aexpr_impl(
121125
expr: Expr,
122126
arena: &mut Arena<AExpr>,
123127
state: &mut ConversionContext,
@@ -281,7 +285,7 @@ fn to_aexpr_impl(
281285
options,
282286
} => {
283287
let e = to_expr_irs(input, arena)?;
284-
set_function_output_name(&e, state, || Cow::Borrowed(options.fmt_str));
288+
set_function_output_name(&e, state, || PlSmallStr::from_static(options.fmt_str));
285289
AExpr::AnonymousFunction {
286290
input: e,
287291
function,
@@ -293,55 +297,7 @@ fn to_aexpr_impl(
293297
input,
294298
function,
295299
options,
296-
} => {
297-
match function {
298-
// This can be created by col(*).is_null() on empty dataframes.
299-
FunctionExpr::Boolean(
300-
BooleanFunction::AllHorizontal | BooleanFunction::AnyHorizontal,
301-
) if input.is_empty() => {
302-
return to_aexpr_impl(lit(true), arena, state);
303-
},
304-
// Convert to binary expression as the optimizer understands those.
305-
// Don't exceed 128 expressions as we might stackoverflow.
306-
FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => {
307-
if input.len() < 128 {
308-
let expr = input
309-
.into_iter()
310-
.reduce(|l, r| l.logical_and(r))
311-
.unwrap()
312-
.cast(DataType::Boolean);
313-
return to_aexpr_impl(expr, arena, state);
314-
}
315-
},
316-
FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => {
317-
if input.len() < 128 {
318-
let expr = input
319-
.into_iter()
320-
.reduce(|l, r| l.logical_or(r))
321-
.unwrap()
322-
.cast(DataType::Boolean);
323-
return to_aexpr_impl(expr, arena, state);
324-
}
325-
},
326-
_ => {},
327-
}
328-
329-
let e = to_expr_irs(input, arena)?;
330-
331-
if state.output_name.is_none() {
332-
// Handles special case functions like `struct.field`.
333-
if let Some(name) = function.output_name() {
334-
state.output_name = name
335-
} else {
336-
set_function_output_name(&e, state, || Cow::Owned(format!("{}", &function)));
337-
}
338-
}
339-
AExpr::Function {
340-
input: e,
341-
function,
342-
options,
343-
}
344-
},
300+
} => return convert_functions(input, function, options, arena, state),
345301
Expr::Window {
346302
function,
347303
partition_by,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use arrow::legacy::error::PolarsResult;
2+
use polars_utils::arena::{Arena, Node};
3+
use polars_utils::format_pl_smallstr;
4+
5+
use super::*;
6+
use crate::dsl::{Expr, FunctionExpr};
7+
use crate::plans::AExpr;
8+
use crate::prelude::FunctionOptions;
9+
10+
pub(super) fn convert_functions(
11+
input: Vec<Expr>,
12+
function: FunctionExpr,
13+
options: FunctionOptions,
14+
arena: &mut Arena<AExpr>,
15+
state: &mut ConversionContext,
16+
) -> PolarsResult<Node> {
17+
match function {
18+
// This can be created by col(*).is_null() on empty dataframes.
19+
FunctionExpr::Boolean(BooleanFunction::AllHorizontal | BooleanFunction::AnyHorizontal)
20+
if input.is_empty() =>
21+
{
22+
return to_aexpr_impl(lit(true), arena, state);
23+
},
24+
// Convert to binary expression as the optimizer understands those.
25+
// Don't exceed 128 expressions as we might stackoverflow.
26+
FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => {
27+
if input.len() < 128 {
28+
let expr = input
29+
.into_iter()
30+
.reduce(|l, r| l.logical_and(r))
31+
.unwrap()
32+
.cast(DataType::Boolean);
33+
return to_aexpr_impl(expr, arena, state);
34+
}
35+
},
36+
FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => {
37+
if input.len() < 128 {
38+
let expr = input
39+
.into_iter()
40+
.reduce(|l, r| l.logical_or(r))
41+
.unwrap()
42+
.cast(DataType::Boolean);
43+
return to_aexpr_impl(expr, arena, state);
44+
}
45+
},
46+
_ => {},
47+
}
48+
49+
let e = to_expr_irs(input, arena)?;
50+
51+
if state.output_name.is_none() {
52+
// Handles special case functions like `struct.field`.
53+
if let Some(name) = function.output_name() {
54+
state.output_name = name
55+
} else {
56+
set_function_output_name(&e, state, || format_pl_smallstr!("{}", &function));
57+
}
58+
}
59+
60+
let ae_function = AExpr::Function {
61+
input: e,
62+
function,
63+
options,
64+
};
65+
Ok(arena.add(ae_function))
66+
}

crates/polars-plan/src/plans/conversion/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ mod ir_to_dsl;
1212
mod scans;
1313
mod stack_opt;
1414

15-
use std::borrow::Cow;
1615
use std::sync::{Arc, Mutex, RwLock};
1716

1817
pub use dsl_to_ir::*;
@@ -21,6 +20,7 @@ pub use ir_to_dsl::*;
2120
use polars_core::prelude::*;
2221
use polars_utils::vec::ConvertVec;
2322
use recursive::recursive;
23+
mod functions;
2424
pub(crate) mod type_coercion;
2525

2626
pub(crate) use expr_expansion::{expand_selectors, is_regex_projection, prepare_projection};

crates/polars-plan/src/plans/conversion/type_coercion/binary.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,13 @@ pub(super) fn process_binary(
296296
st = String
297297
}
298298

299-
// only cast if the type is not already the super type.
299+
// TODO! raise here?
300+
// We should at least never cast to Unknown.
301+
if matches!(st, DataType::Unknown(UnknownKind::Any)) {
302+
return Ok(None);
303+
}
304+
305+
// Only cast if the type is not already the super type.
300306
// this can prevent an expensive flattening and subsequent aggregation
301307
// in a group_by context. To be able to cast the groups need to be
302308
// flattened
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use either::Either;
2+
3+
use super::*;
4+
5+
pub(super) fn get_function_dtypes(
6+
input: &[ExprIR],
7+
expr_arena: &Arena<AExpr>,
8+
input_schema: &Schema,
9+
function: &FunctionExpr,
10+
mut options: FunctionOptions,
11+
) -> PolarsResult<Either<Vec<DataType>, AExpr>> {
12+
let mut early_return = move || {
13+
// Next iteration this will not hit anymore as options is updated.
14+
options.cast_to_supertypes = None;
15+
Ok(Either::Right(AExpr::Function {
16+
function: function.clone(),
17+
input: input.to_vec(),
18+
options,
19+
}))
20+
};
21+
22+
let mut dtypes = Vec::with_capacity(input.len());
23+
let mut first = true;
24+
for e in input {
25+
let Some((_, dtype)) = get_aexpr_and_type(expr_arena, e.node(), input_schema) else {
26+
return early_return();
27+
};
28+
29+
if first {
30+
check_namespace(function, &dtype)?;
31+
first = false;
32+
}
33+
// Ignore Unknown in the inputs.
34+
// We will raise if we cannot find the supertype later.
35+
match dtype {
36+
DataType::Unknown(UnknownKind::Any) => {
37+
return early_return();
38+
},
39+
_ => dtypes.push(dtype),
40+
}
41+
}
42+
43+
if dtypes.iter().all_equal() {
44+
return early_return();
45+
}
46+
Ok(Either::Left(dtypes))
47+
}
48+
49+
// `str` namespace belongs to `String`
50+
// `cat` namespace belongs to `Categorical` etc.
51+
fn check_namespace(function: &FunctionExpr, first_dtype: &DataType) -> PolarsResult<()> {
52+
match function {
53+
#[cfg(feature = "strings")]
54+
FunctionExpr::StringExpr(_) => {
55+
polars_ensure!(first_dtype == &DataType::String, InvalidOperation: "expected String type, got: {}", first_dtype)
56+
},
57+
FunctionExpr::BinaryExpr(_) => {
58+
polars_ensure!(first_dtype == &DataType::Binary, InvalidOperation: "expected Binary type, got: {}", first_dtype)
59+
},
60+
#[cfg(feature = "temporal")]
61+
FunctionExpr::TemporalExpr(_) => {
62+
polars_ensure!(first_dtype.is_temporal(), InvalidOperation: "expected Date(time)/Duration type, got: {}", first_dtype)
63+
},
64+
FunctionExpr::ListExpr(_) => {
65+
polars_ensure!(matches!(first_dtype, DataType::List(_)), InvalidOperation: "expected List type, got: {}", first_dtype)
66+
},
67+
#[cfg(feature = "dtype-array")]
68+
FunctionExpr::ArrayExpr(_) => {
69+
polars_ensure!(matches!(first_dtype, DataType::Array(_, _)), InvalidOperation: "expected Array type, got: {}", first_dtype)
70+
},
71+
#[cfg(feature = "dtype-struct")]
72+
FunctionExpr::StructExpr(_) => {
73+
polars_ensure!(matches!(first_dtype, DataType::Struct(_)), InvalidOperation: "expected Struct type, got: {}", first_dtype)
74+
},
75+
#[cfg(feature = "dtype-categorical")]
76+
FunctionExpr::Categorical(_) => {
77+
polars_ensure!(matches!(first_dtype, DataType::Categorical(_, _)), InvalidOperation: "expected Struct type, got: {}", first_dtype)
78+
},
79+
_ => {},
80+
}
81+
82+
Ok(())
83+
}

0 commit comments

Comments
 (0)