[SPARK-41446][CONNECT][PYTHON] Make createDataFrame support schema and more input dataset types
#38979
Changes from all commits: abb5936, 1479737, 02156d4, 21a7a5c, 5b09875, c599bb6, 9002775, f79b87f
Changes to the `pyspark.sql.connect.plan` module (`LocalRelation`):

```diff
@@ -17,11 +17,13 @@
 from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
 import functools
 import pandas
 import pyarrow as pa
 
+from pyspark.sql.types import DataType
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.column import Column, SortOrder, ColumnReference
+from pyspark.sql.connect.types import pyspark_types_to_proto_types
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName
```

```diff
@@ -167,21 +169,34 @@ def _repr_html_(self) -> str:
 
 
 class LocalRelation(LogicalPlan):
-    """Creates a LocalRelation plan object based on a Pandas DataFrame."""
+    """Creates a LocalRelation plan object based on a PyArrow Table."""
 
-    def __init__(self, pdf: "pandas.DataFrame") -> None:
+    def __init__(
+        self,
+        table: "pa.Table",
+        schema: Optional[Union[DataType, str]] = None,
+    ) -> None:
         super().__init__(None)
-        self._pdf = pdf
+        assert table is not None and isinstance(table, pa.Table)
+        self._table = table
+
+        if schema is not None:
+            assert isinstance(schema, (DataType, str))
+        self._schema = schema
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         sink = pa.BufferOutputStream()
-        table = pa.Table.from_pandas(self._pdf)
-        with pa.ipc.new_stream(sink, table.schema) as writer:
-            for b in table.to_batches():
+        with pa.ipc.new_stream(sink, self._table.schema) as writer:
+            for b in self._table.to_batches():
                 writer.write_batch(b)
 
         plan = proto.Relation()
         plan.local_relation.data = sink.getvalue().to_pybytes()
+        if self._schema is not None:
+            if isinstance(self._schema, DataType):
+                plan.local_relation.datatype.CopyFrom(pyspark_types_to_proto_types(self._schema))
+            elif isinstance(self._schema, str):
+                plan.local_relation.datatype_str = self._schema
         return plan
 
     def print(self, indent: int = 0) -> str:
```
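For reference, the serialization that `LocalRelation.plan()` performs is a plain Arrow IPC stream round trip. A minimal standalone sketch using only `pyarrow` (the table contents are made up, and the read-back step is only an illustration of what the server can do with `local_relation.data`):

```python
import pyarrow as pa

# A small Arrow table, analogous to what createDataFrame builds via pa.Table.from_pandas().
table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

# Serialize the table into Arrow IPC stream bytes, as LocalRelation.plan() does.
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    for batch in table.to_batches():
        writer.write_batch(batch)
data = sink.getvalue().to_pybytes()

# The receiving side can rebuild an identical table from those bytes.
restored = pa.ipc.open_stream(data).read_all()
assert restored.equals(table)
```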
Changes to the Spark Connect session implementation:

```diff
@@ -16,17 +16,35 @@
 #
 
 from threading import RLock
-from typing import Optional, Any, Union, Dict, cast, overload
+from collections.abc import Sized
 
+import numpy as np
 import pandas as pd
+import pyarrow as pa
+
+from pyspark.sql.types import DataType, StructType
 
-import pyspark.sql.types
 from pyspark.sql.connect.client import SparkConnectClient
 from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.plan import SQL, Range
+from pyspark.sql.connect.plan import SQL, Range, LocalRelation
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.utils import to_str
 from . import plan
-from ._typing import OptionalPrimitiveType
+
+from typing import (
+    Optional,
+    Any,
+    Union,
+    Dict,
+    List,
+    Tuple,
+    cast,
+    overload,
+    Iterable,
+    TYPE_CHECKING,
+)
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import OptionalPrimitiveType
 
 
 # TODO(SPARK-38912): This method can be dropped once support for Python 3.8 is dropped
```
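One note on the new imports above: `collections.abc.Sized` backs the emptiness check added to `createDataFrame`, since `len()` cannot be applied to arbitrary iterables. A small standalone illustration (the helper name `is_known_empty` is invented for this sketch):

```python
from collections.abc import Sized


def is_known_empty(data) -> bool:
    # len() is only defined for sized containers; a generator, for example, is not Sized.
    return isinstance(data, Sized) and len(data) == 0


assert is_known_empty([])
assert not is_known_empty([(1, "a"), (2, "b")])
assert not is_known_empty(x for x in [])  # emptiness unknown without consuming the generator
```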
```diff
@@ -240,7 +258,11 @@ def read(self) -> "DataFrameReader":
         """
         return DataFrameReader(self)
 
-    def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame":
+    def createDataFrame(
+        self,
+        data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]],
+        schema: Optional[Union[StructType, str, List[str], Tuple[str, ...]]] = None,
+    ) -> "DataFrame":
         """
         Creates a :class:`DataFrame` from a :class:`pandas.DataFrame`.
```
```diff
@@ -249,7 +271,15 @@ def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame":
 
         Parameters
         ----------
-        data : :class:`pandas.DataFrame`
+        data : :class:`pandas.DataFrame` or :class:`list`, or :class:`numpy.ndarray`.
+        schema : :class:`pyspark.sql.types.DataType`, str or list, optional
+
+            When ``schema`` is :class:`pyspark.sql.types.DataType` or a datatype string, it must
+            match the real data, or an exception will be thrown at runtime. If the given schema is
+            not :class:`pyspark.sql.types.StructType`, it will be wrapped into a
+            :class:`pyspark.sql.types.StructType` as its only field, and the field name will be
+            "value". Each record will also be wrapped into a tuple, which can be converted to row
+            later.
 
         Returns
         -------
```
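To make the parameter description above concrete, a hedged usage sketch of the three schema forms handled by the new code path (this assumes a Spark Connect session bound to `spark`; the data values are invented):

```python
import pandas as pd
from pyspark.sql.types import LongType, StringType, StructField, StructType

pdf = pd.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]})

# schema as a StructType: field names and types come from the struct
struct = StructType([StructField("id", LongType()), StructField("name", StringType())])
df1 = spark.createDataFrame(pdf, schema=struct)

# schema as a DDL-formatted datatype string
df2 = spark.createDataFrame(pdf, schema="id long, name string")

# schema as a list of column names: the data is kept, only the columns are renamed
df3 = spark.createDataFrame(pdf, schema=["a", "b"])
```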
```diff
@@ -264,9 +294,71 @@ def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame":
         """
         assert data is not None
-        if len(data) == 0:
+        if isinstance(data, DataFrame):
+            raise TypeError("data is already a DataFrame")
+        if isinstance(data, Sized) and len(data) == 0:
             raise ValueError("Input data cannot be empty")
-        return DataFrame.withPlan(plan.LocalRelation(data), self)
+
+        _schema: Optional[StructType] = None
+        _schema_str: Optional[str] = None
+        _cols: Optional[List[str]] = None
+
+        if isinstance(schema, StructType):
+            _schema = schema
+
+        elif isinstance(schema, str):
+            _schema_str = schema
+
+        elif isinstance(schema, (list, tuple)):
+            # Must re-encode any unicode strings to be consistent with StructField names
+            _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
+
+        # Create the Pandas DataFrame
+        if isinstance(data, pd.DataFrame):
+            pdf = data
+
+        elif isinstance(data, np.ndarray):
+            # `data` of numpy.ndarray type will be converted to a pandas DataFrame,
+            if data.ndim not in [1, 2]:
+                raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
+
+            pdf = pd.DataFrame(data)
+
+            if _cols is None:
+                if data.ndim == 1 or data.shape[1] == 1:
+                    _cols = ["value"]
+                else:
+                    _cols = ["_%s" % i for i in range(1, data.shape[1] + 1)]
+
+        else:
+            pdf = pd.DataFrame(list(data))
+
+            if _cols is None:
+                _cols = ["_%s" % i for i in range(1, pdf.shape[1] + 1)]
+
+        # Validate number of columns
+        num_cols = pdf.shape[1]
+        if _schema is not None and len(_schema.fields) != num_cols:
+            raise ValueError(
+                f"Length mismatch: Expected axis has {num_cols} elements, "
+                f"new values have {len(_schema.fields)} elements"
+            )
+        elif _cols is not None and len(_cols) != num_cols:
+            raise ValueError(
+                f"Length mismatch: Expected axis has {num_cols} elements, "
+                f"new values have {len(_cols)} elements"
+            )
+
+        table = pa.Table.from_pandas(pdf)
+
+        if _schema is not None:
+            return DataFrame.withPlan(LocalRelation(table, schema=_schema), self)
+        elif _schema_str is not None:
+            return DataFrame.withPlan(LocalRelation(table, schema=_schema_str), self)
+        elif _cols is not None and len(_cols) > 0:
+            return DataFrame.withPlan(LocalRelation(table), self).toDF(*_cols)
+        else:
+            return DataFrame.withPlan(LocalRelation(table), self)
 
     @property
     def client(self) -> "SparkConnectClient":
```

Inline review comment from the author (Contributor) at this point in the diff: "If we can have a RPC for"
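A usage sketch of the newly supported input types, following the branches above (again assuming a connected `spark` session; the values are illustrative):

```python
import numpy as np

# 1-D ndarray: a single column, named "value" by default
df_a = spark.createDataFrame(np.array([1, 2, 3]))

# 2-D ndarray: columns default to "_1", "_2", ... unless names are given
df_b = spark.createDataFrame(np.array([[1, 2], [3, 4]]), schema=["x", "y"])

# Any other iterable goes through pd.DataFrame(list(data));
# a list of tuples plus a list of column names renames via toDF(*cols)
df_c = spark.createDataFrame([(1, "a"), (2, "b")], schema=["id", "label"])
```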
```diff
@@ -279,9 +371,7 @@ def client(self) -> "SparkConnectClient":
         """
         return self._client
 
-    def register_udf(
-        self, function: Any, return_type: Union[str, pyspark.sql.types.DataType]
-    ) -> str:
+    def register_udf(self, function: Any, return_type: Union[str, DataType]) -> str:
         return self._client.register_udf(function, return_type)
 
     def sql(self, sql_string: str) -> "DataFrame":
```
Review comment: Or we can always pass the string implementation for now (by turning the `DataType` into a JSON representation), i.e. `DataType.json()`.
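As background for this suggestion and the one that follows, a standalone sketch of the existing PySpark helpers being referred to (illustrative only, not part of this PR; parsing a DDL string needs a regular, JVM-backed SparkSession, which is why doing it server-side in AnalyzePlan is attractive):

```python
from pyspark.sql.types import _parse_datatype_json_string, _parse_datatype_string

# Parse a DDL-formatted schema string into a DataType (requires an active JVM-backed session).
schema = _parse_datatype_string("id BIGINT, name STRING")

# Any DataType can also be shipped as its JSON representation and restored on the other side,
# which is what "turning DataType to a JSON representation" refers to.
assert _parse_datatype_json_string(schema.json()) == schema
```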
Review comment: I'm thinking of adding support for `_parse_datatype_string` in AnalyzePlan; then we don't need to add `datatype` and `datatype_str` in `LocalRelation` at all. Then the implementation will be like this (after we implement `DataFrame.to`):