-
Notifications
You must be signed in to change notification settings - Fork 179
chore: Finalize support for SQLFrame #2038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
700532a
WIP
FBruzzesi 0383ee0
WIP
FBruzzesi 8d48013
ruff
FBruzzesi ddd4797
Merge branch 'main' into chore/better-sqlframe-support
FBruzzesi 5735fb5
getting there
FBruzzesi ec0bb9a
Merge branch 'main' into chore/better-sqlframe-support
FBruzzesi c1aa784
duplicate column names check for pyspark
FBruzzesi 78d22fb
xfail unary with 2 elements
FBruzzesi 48ed576
Merge branch 'main' into chore/better-sqlframe-support
FBruzzesi 586f520
good stuff
FBruzzesi 09e46ca
mypy, missing deps
FBruzzesi 0b84b76
rm direct pyspark import
FBruzzesi 2425798
Merge branch 'main' into chore/better-sqlframe-support
FBruzzesi 3d6f0c0
Merge branch 'main' into chore/better-sqlframe-support
FBruzzesi f67e767
merge main, fix join
FBruzzesi 0570287
Merge branch 'main' into chore/better-sqlframe-support
FBruzzesi 0762bc2
Merge branch 'main' into chore/better-sqlframe-support
FBruzzesi 8662d02
ignore numpy deprecation warning
FBruzzesi 493c74d
merge main
FBruzzesi 6cfb2a9
mypy
FBruzzesi 8136110
solve conflicts
FBruzzesi 5532ba3
rm comment regarding the reason for not enabling sqlframe
FBruzzesi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import warnings | ||
| from importlib import import_module | ||
| from typing import TYPE_CHECKING | ||
| from typing import Any | ||
| from typing import Literal | ||
|
|
@@ -13,6 +14,7 @@ | |
| from narwhals.typing import CompliantLazyFrame | ||
| from narwhals.utils import Implementation | ||
| from narwhals.utils import check_column_exists | ||
| from narwhals.utils import check_column_names_are_unique | ||
| from narwhals.utils import find_stacklevel | ||
| from narwhals.utils import import_dtypes_module | ||
| from narwhals.utils import parse_columns_to_drop | ||
|
|
@@ -23,6 +25,7 @@ | |
| from types import ModuleType | ||
|
|
||
| import pyarrow as pa | ||
| from pyspark.sql import Column | ||
| from pyspark.sql import DataFrame | ||
| from typing_extensions import Self | ||
|
|
||
|
|
@@ -41,7 +44,10 @@ def __init__( | |
| backend_version: tuple[int, ...], | ||
| version: Version, | ||
| implementation: Implementation, | ||
| validate_column_names: bool, | ||
| ) -> None: | ||
| if validate_column_names: | ||
| check_column_names_are_unique(native_dataframe.columns) | ||
| self._native_frame = native_dataframe | ||
| self._backend_version = backend_version | ||
| self._implementation = implementation | ||
|
|
@@ -51,33 +57,50 @@ def __init__( | |
| @property | ||
| def _F(self: Self) -> Any: # noqa: N802 | ||
| if self._implementation is Implementation.SQLFRAME: | ||
| from sqlframe.duckdb import functions | ||
| from sqlframe.base.session import _BaseSession | ||
|
|
||
| return import_module( | ||
| f"sqlframe.{_BaseSession().execution_dialect_name}.functions" | ||
| ) | ||
|
|
||
| return functions | ||
| from pyspark.sql import functions | ||
|
|
||
| return functions | ||
|
|
||
| @property | ||
| def _native_dtypes(self: Self) -> Any: | ||
| if self._implementation is Implementation.SQLFRAME: | ||
| from sqlframe.duckdb import types | ||
| from sqlframe.base.session import _BaseSession | ||
|
|
||
| return import_module( | ||
| f"sqlframe.{_BaseSession().execution_dialect_name}.types" | ||
| ) | ||
|
|
||
| return types | ||
| from pyspark.sql import types | ||
|
|
||
| return types | ||
|
|
||
| @property | ||
| def _Window(self: Self) -> Any: # noqa: N802 | ||
| if self._implementation is Implementation.SQLFRAME: | ||
| from sqlframe.duckdb import Window | ||
| from sqlframe.base.session import _BaseSession | ||
|
|
||
| _window = import_module( | ||
| f"sqlframe.{_BaseSession().execution_dialect_name}.window" | ||
| ) | ||
| return _window.Window | ||
|
|
||
| return Window | ||
| from pyspark.sql import Window | ||
|
|
||
| return Window | ||
|
|
||
| @property | ||
| def _session(self: Self) -> Any: | ||
| if self._implementation is Implementation.SQLFRAME: | ||
| return self._native_frame.session | ||
|
|
||
| return self._native_frame.sparkSession | ||
|
|
||
| def __native_namespace__(self: Self) -> ModuleType: # pragma: no cover | ||
| return self._implementation.to_native_namespace() | ||
|
|
||
|
|
@@ -99,14 +122,18 @@ def _change_version(self: Self, version: Version) -> Self: | |
| backend_version=self._backend_version, | ||
| version=version, | ||
| implementation=self._implementation, | ||
| validate_column_names=False, | ||
| ) | ||
|
|
||
| def _from_native_frame(self: Self, df: DataFrame) -> Self: | ||
| def _from_native_frame( | ||
| self: Self, df: DataFrame, *, validate_column_names: bool = True | ||
| ) -> Self: | ||
| return self.__class__( | ||
| df, | ||
| backend_version=self._backend_version, | ||
| version=self._version, | ||
| implementation=self._implementation, | ||
| validate_column_names=validate_column_names, | ||
| ) | ||
|
|
||
| def _collect_to_arrow(self) -> pa.Table: | ||
|
|
@@ -205,7 +232,9 @@ def collect( | |
| raise ValueError(msg) # pragma: no cover | ||
|
|
||
| def simple_select(self: Self, *column_names: str) -> Self: | ||
| return self._from_native_frame(self._native_frame.select(*column_names)) | ||
| return self._from_native_frame( | ||
| self._native_frame.select(*column_names), validate_column_names=False | ||
| ) | ||
|
|
||
| def aggregate( | ||
| self: Self, | ||
|
|
@@ -214,7 +243,9 @@ def aggregate( | |
| new_columns = parse_exprs(self, *exprs) | ||
|
|
||
| new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()] | ||
| return self._from_native_frame(self._native_frame.agg(*new_columns_list)) | ||
| return self._from_native_frame( | ||
| self._native_frame.agg(*new_columns_list), validate_column_names=False | ||
| ) | ||
|
|
||
| def select( | ||
| self: Self, | ||
|
|
@@ -224,17 +255,18 @@ def select( | |
|
|
||
| if not new_columns: | ||
| # return empty dataframe, like Polars does | ||
| spark_session = self._native_frame.sparkSession | ||
| spark_df = spark_session.createDataFrame( | ||
| spark_df = self._session.createDataFrame( | ||
| [], self._native_dtypes.StructType([]) | ||
| ) | ||
|
|
||
| return self._from_native_frame(spark_df) | ||
| return self._from_native_frame(spark_df, validate_column_names=False) | ||
|
|
||
| new_columns_list = [ | ||
| col.alias(col_name) for (col_name, col) in new_columns.items() | ||
| ] | ||
| return self._from_native_frame(self._native_frame.select(*new_columns_list)) | ||
| return self._from_native_frame( | ||
| self._native_frame.select(*new_columns_list), validate_column_names=False | ||
| ) | ||
|
|
||
| def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self: | ||
| new_columns = parse_exprs(self, *exprs) | ||
|
|
@@ -244,7 +276,7 @@ def filter(self: Self, predicate: SparkLikeExpr) -> Self: | |
| # `[0]` is safe as the predicate's expression only returns a single column | ||
| condition = predicate._call(self)[0] | ||
| spark_df = self._native_frame.where(condition) | ||
| return self._from_native_frame(spark_df) | ||
| return self._from_native_frame(spark_df, validate_column_names=False) | ||
|
|
||
| @property | ||
| def schema(self: Self) -> dict[str, DType]: | ||
|
|
@@ -264,13 +296,13 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001 | |
| columns_to_drop = parse_columns_to_drop( | ||
| compliant_frame=self, columns=columns, strict=strict | ||
| ) | ||
| return self._from_native_frame(self._native_frame.drop(*columns_to_drop)) | ||
| return self._from_native_frame( | ||
| self._native_frame.drop(*columns_to_drop), validate_column_names=False | ||
| ) | ||
|
|
||
| def head(self: Self, n: int) -> Self: | ||
| spark_session = self._native_frame.sparkSession | ||
|
|
||
| return self._from_native_frame( | ||
| spark_session.createDataFrame(self._native_frame.take(num=n)) | ||
| self._native_frame.limit(num=n), validate_column_names=False | ||
| ) | ||
|
|
||
| def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy: | ||
|
|
@@ -301,10 +333,14 @@ def sort( | |
| ) | ||
|
|
||
| sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)] | ||
| return self._from_native_frame(self._native_frame.sort(*sort_cols)) | ||
| return self._from_native_frame( | ||
| self._native_frame.sort(*sort_cols), validate_column_names=False | ||
| ) | ||
|
|
||
| def drop_nulls(self: Self, subset: list[str] | None) -> Self: | ||
| return self._from_native_frame(self._native_frame.dropna(subset=subset)) | ||
| return self._from_native_frame( | ||
| self._native_frame.dropna(subset=subset), validate_column_names=False | ||
| ) | ||
|
|
||
| def rename(self: Self, mapping: dict[str, str]) -> Self: | ||
| rename_mapping = { | ||
|
|
@@ -326,7 +362,9 @@ def unique( | |
| msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`." | ||
| raise ValueError(msg) | ||
| check_column_exists(self.columns, subset) | ||
| return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset)) | ||
| return self._from_native_frame( | ||
| self._native_frame.dropDuplicates(subset=subset), validate_column_names=False | ||
| ) | ||
|
|
||
| def join( | ||
| self: Self, | ||
|
|
@@ -357,7 +395,7 @@ def join( | |
| for colname in list(set(right_columns).difference(set(right_on or []))) | ||
| }, | ||
| } | ||
| other = other_native.select( | ||
| other_native = other_native.select( | ||
| [self._F.col(old).alias(new) for old, new in rename_mapping.items()] | ||
| ) | ||
|
|
||
|
|
@@ -375,7 +413,7 @@ def join( | |
| ] | ||
| ) | ||
| return self._from_native_frame( | ||
| self_native.join(other, on=left_on, how=how).select(col_order) | ||
| self_native.join(other_native, on=left_on, how=how).select(col_order) | ||
| ) | ||
|
|
||
| def explode(self: Self, columns: list[str]) -> Self: | ||
|
|
@@ -402,16 +440,51 @@ def explode(self: Self, columns: list[str]) -> Self: | |
| ) | ||
| raise NotImplementedError(msg) | ||
|
|
||
| return self._from_native_frame( | ||
| native_frame.select( | ||
| *[ | ||
| self._F.col(col_name).alias(col_name) | ||
| if col_name != columns[0] | ||
| else self._F.explode_outer(col_name).alias(col_name) | ||
| for col_name in column_names | ||
| ] | ||
| if self._implementation.is_pyspark(): | ||
| return self._from_native_frame( | ||
| native_frame.select( | ||
| *[ | ||
| self._F.col(col_name).alias(col_name) | ||
| if col_name != columns[0] | ||
| else self._F.explode_outer(col_name).alias(col_name) | ||
| for col_name in column_names | ||
| ] | ||
| ), | ||
| validate_column_names=False, | ||
| ) | ||
| ) | ||
| elif self._implementation.is_sqlframe(): | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the same workaround as per duckdb native. As order is not guaranteed, I think that's ok |
||
| # Not every sqlframe dialect supports `explode_outer` function | ||
| # (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289) | ||
| # therefore we simply explode the array column which will ignore nulls and | ||
| # zero sized arrays, and append these specific condition with nulls (to | ||
| # match polars behavior). | ||
|
|
||
| def null_condition(col_name: str) -> Column: | ||
| return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0) | ||
|
|
||
| return self._from_native_frame( | ||
| native_frame.select( | ||
| *[ | ||
| self._F.col(col_name).alias(col_name) | ||
| if col_name != columns[0] | ||
| else self._F.explode(col_name).alias(col_name) | ||
| for col_name in column_names | ||
| ] | ||
| ).union( | ||
| native_frame.filter(null_condition(columns[0])).select( | ||
| *[ | ||
| self._F.col(col_name).alias(col_name) | ||
| if col_name != columns[0] | ||
| else self._F.lit(None).alias(col_name) | ||
| for col_name in column_names | ||
| ] | ||
| ) | ||
| ), | ||
| validate_column_names=False, | ||
| ) | ||
| else: # pragma: no cover | ||
| msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues" | ||
| raise AssertionError(msg) | ||
|
|
||
| def unpivot( | ||
| self: Self, | ||
|
|
@@ -420,6 +493,15 @@ def unpivot( | |
| variable_name: str, | ||
| value_name: str, | ||
| ) -> Self: | ||
| if self._implementation.is_sqlframe(): | ||
| if variable_name == "": | ||
| msg = "`variable_name` cannot be empty string for sqlframe backend." | ||
| raise NotImplementedError(msg) | ||
|
|
||
| if value_name == "": | ||
| msg = "`value_name` cannot be empty string for sqlframe backend." | ||
| raise NotImplementedError(msg) | ||
|
|
||
| ids = tuple(self.columns) if index is None else tuple(index) | ||
| values = ( | ||
| tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on) | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While pyspark does it out of the box, sqlframe with duckdb backend doesn't.
I think it is good to have it regardless in order to standardize errors and error messages