diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index c7927e2014..a58c595aa8 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -26,6 +26,7 @@ from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.typing import CompliantNamespace from narwhals.utils import Implementation +from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module if TYPE_CHECKING: @@ -161,7 +162,7 @@ def all(self: Self) -> ArrowExpr: ], depth=0, function_name="all", - evaluate_output_names=lambda df: df.columns, + evaluate_output_names=get_column_names, alias_output_names=None, backend_version=self._backend_version, version=self._version, diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 93b02aab4c..5f9fa982f6 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -23,6 +23,7 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.typing import CompliantNamespace +from narwhals.utils import get_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -55,7 +56,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: func, depth=0, function_name="all", - evaluate_output_names=lambda df: df.columns, + evaluate_output_names=get_column_names, alias_output_names=None, backend_version=self._backend_version, version=self._version, diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 439e22c69b..4ea72fae3b 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -24,6 +24,7 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.typing import CompliantNamespace +from narwhals.utils import get_column_names if TYPE_CHECKING: import duckdb @@ -52,7 +53,7 @@ def _all(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return DuckDBExpr( call=_all, function_name="all", - evaluate_output_names=lambda df: df.columns, + evaluate_output_names=get_column_names, alias_output_names=None, backend_version=self._backend_version, version=self._version, diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index f369bb3094..4e1b887fc0 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -22,6 +22,7 @@ from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import vertical_concat from narwhals.typing import CompliantNamespace +from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module if TYPE_CHECKING: @@ -135,7 +136,7 @@ def all(self: Self) -> PandasLikeExpr: ], depth=0, function_name="all", - evaluate_output_names=lambda df: df.columns, + evaluate_output_names=get_column_names, alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 1b68d0734a..74d0841431 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -17,6 +17,7 @@ from narwhals._spark_like.utils import maybe_evaluate_expr from narwhals._spark_like.utils import narwhals_to_native_dtype from narwhals.typing import CompliantNamespace +from narwhals.utils import get_column_names if TYPE_CHECKING: from pyspark.sql import Column @@ -51,7 +52,7 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]: return SparkLikeExpr( call=_all, function_name="all", - evaluate_output_names=lambda df: df.columns, + evaluate_output_names=get_column_names, alias_output_names=None, backend_version=self._backend_version, version=self._version, diff --git a/narwhals/typing.py b/narwhals/typing.py index fe04c4703d..9c868694a7 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -62,6 +62,9 @@ def aggregate(self, *exprs: Any) -> Self: ... # `select` where all args are aggregations or literals # (so, no broadcasting is necessary). + @property + def columns(self) -> Sequence[str]: ... + class CompliantLazyFrame(Protocol): def __narwhals_lazyframe__(self) -> Self: ... @@ -73,6 +76,9 @@ def aggregate(self, *exprs: Any) -> Self: ... # `select` where all args are aggregations or literals # (so, no broadcasting is necessary). + @property + def columns(self) -> Sequence[str]: ... + CompliantFrameT_contra = TypeVar( "CompliantFrameT_contra", diff --git a/narwhals/utils.py b/narwhals/utils.py index da8eadb80e..9d286398c7 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -103,6 +103,10 @@ class _FullContext(_StoresImplementation, _LimitedContext, Protocol): # noqa: P - `_version` """ + class _StoresColumns(Protocol): + @property + def columns(self) -> Sequence[str]: ... + class Version(Enum): V1 = auto() @@ -1344,6 +1348,10 @@ def dtype_matches_time_unit_and_time_zone( ) +def get_column_names(frame: _StoresColumns, /) -> Sequence[str]: + return frame.columns + + def _hasattr_static(obj: Any, attr: str) -> bool: sentinel = object() return getattr_static(obj, attr, sentinel) is not sentinel