diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index e8d17924b9..60798de42d 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -504,7 +504,7 @@ def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata: ) -def combine_metadata( # noqa: C901, PLR0912 +def combine_metadata( *args: IntoExpr | object | None, str_as_lit: bool, allow_multi_output: bool, @@ -526,17 +526,17 @@ def combine_metadata( # noqa: C901, PLR0912 # result preserves length if at least one input does result_preserves_length = False # result is elementwise if all inputs are elementwise - result_is_not_elementwise = False + result_is_elementwise = True # result is scalar-like if all inputs are scalar-like - result_is_not_scalar_like = False + result_is_scalar_like = True # result is literal if all inputs are literal - result_is_not_literal = False + result_is_literal = True - for i, arg in enumerate(args): # noqa: PLR1702 + for i, arg in enumerate(args): if (isinstance(arg, str) and not str_as_lit) or is_series(arg): result_preserves_length = True - result_is_not_scalar_like = True - result_is_not_literal = True + result_is_scalar_like = False + result_is_literal = False elif is_expr(arg): metadata = arg._metadata if metadata.expansion_kind.is_multi_output(): @@ -549,24 +549,19 @@ def combine_metadata( # noqa: C901, PLR0912 ) raise MultiOutputExpressionError(msg) if not to_single_output: - if i == 0: - result_expansion_kind = expansion_kind - else: - result_expansion_kind = result_expansion_kind & expansion_kind + result_expansion_kind = ( + result_expansion_kind & expansion_kind + if i > 0 + else expansion_kind + ) - if metadata.has_windows: - result_has_windows = True + result_has_windows |= metadata.has_windows result_n_orderable_ops += metadata.n_orderable_ops - if metadata.preserves_length: - result_preserves_length = True - if not metadata.is_elementwise: - result_is_not_elementwise = True - if not metadata.is_scalar_like: - result_is_not_scalar_like = True - if not metadata.is_literal: - result_is_not_literal = True - if metadata.is_filtration: - n_filtrations += 1 + result_preserves_length |= metadata.preserves_length + result_is_elementwise &= metadata.is_elementwise + result_is_scalar_like &= metadata.is_scalar_like + result_is_literal &= metadata.is_literal + n_filtrations += int(metadata.is_filtration) if n_filtrations > 1: msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation" @@ -581,9 +576,9 @@ def combine_metadata( # noqa: C901, PLR0912 has_windows=result_has_windows, n_orderable_ops=result_n_orderable_ops, preserves_length=result_preserves_length, - is_elementwise=not result_is_not_elementwise, - is_scalar_like=not result_is_not_scalar_like, - is_literal=not result_is_not_literal, + is_elementwise=result_is_elementwise, + is_scalar_like=result_is_scalar_like, + is_literal=result_is_literal, )