@@ -20,6 +20,7 @@ syntax = 'proto3';
package spark.connect;

import "spark/connect/expressions.proto";
import "spark/connect/types.proto";

option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";
@@ -305,6 +306,17 @@ message LocalRelation {
// Local collection data serialized into Arrow IPC streaming format which contains
// the schema of the data.
bytes data = 1;

// (Optional) The user-provided schema.
//
// The server side will update the column names and data types according to this schema.
oneof schema {

DataType datatype = 2;

// Server will use Catalyst parser to parse this string to DataType.
string datatype_str = 3;
@HyukjinKwon (Member), Dec 9, 2022:
Or we can always pass string implementation for now (by turning DataType to a JSON representation), DataType.json()
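As a rough illustration of this suggestion (assuming the server were extended to accept the JSON form, e.g. via DataType.fromJson, which this PR does not do):

from pyspark.sql.types import StructType, StructField, IntegerType, StringType

schema = StructType([StructField("id", IntegerType()), StructField("name", StringType())])
schema.json()  # compact JSON string describing the struct, which the server could rebuild into a DataType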

Contributor Author:
I'm thinking of adding support for _parse_datatype_string in AnalyzePlan; then we don't need to add datatype and datatype_str in LocalRelation at all.

Then the implementation would look like this (after we implement DataFrame.to):

schema = _parse_datatype_string(schema_str)
return DataFrame.withPlan(LocalRelation(table), self).toDF(*schema.fieldNames()).to(schema)

}
}

// Relation of type [[Sample]] that samples a fraction of the dataset.
@@ -371,6 +371,21 @@ class SparkConnectPlanner(session: SparkSession) {
}
}

private def parseDatatypeString(sqlText: String): DataType = {
val parser = session.sessionState.sqlParser
try {
parser.parseTableSchema(sqlText)
} catch {
case _: ParseException =>
try {
parser.parseDataType(sqlText)
} catch {
case _: ParseException =>
parser.parseDataType(s"struct<${sqlText.trim}>")
}
}
}

private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
Iterator(rel.getData.toByteArray),
@@ -380,7 +395,28 @@
}
val attributes = structType.toAttributes
val proj = UnsafeProjection.create(attributes, attributes)
new logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)

if (!rel.hasDatatype && !rel.hasDatatypeStr) {
return relation
}

val schemaType = if (rel.hasDatatype) {
DataTypeProtoConverter.toCatalystType(rel.getDatatype)
} else {
parseDatatypeString(rel.getDatatypeStr)
}

val schemaStruct = schemaType match {
case s: StructType => s
case d => StructType(Seq(StructField("value", d)))
}

Dataset
.ofRows(session, logicalPlan = relation)
.toDF(schemaStruct.names: _*)
.to(schemaStruct)
.logicalPlan
}

private def transformReadRel(rel: proto.Read): LogicalPlan = {
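For reference, parseDatatypeString above mirrors the fallback order of PySpark's _parse_datatype_string (table schema, then plain data type, then a struct<...> wrapping). An illustrative sketch of the accepted forms, shown through the Python helper:

from pyspark.sql.types import _parse_datatype_string

_parse_datatype_string("a int, b string")    # table-schema form, yields a StructType
_parse_datatype_string("array<int>")         # a single data type, yields ArrayType(IntegerType(), True)
_parse_datatype_string("a: int, b: string")  # also accepted, equivalent to struct<a: int, b: string>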
31 changes: 23 additions & 8 deletions python/pyspark/sql/connect/plan.py
@@ -17,11 +17,13 @@

from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
import functools
import pandas
import pyarrow as pa

from pyspark.sql.types import DataType

import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.column import Column, SortOrder, ColumnReference

from pyspark.sql.connect.types import pyspark_types_to_proto_types

if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName
@@ -167,21 +169,34 @@ def _repr_html_(self) -> str:


class LocalRelation(LogicalPlan):
"""Creates a LocalRelation plan object based on a Pandas DataFrame."""
"""Creates a LocalRelation plan object based on a PyArrow Table."""

def __init__(self, pdf: "pandas.DataFrame") -> None:
def __init__(
self,
table: "pa.Table",
schema: Optional[Union[DataType, str]] = None,
) -> None:
super().__init__(None)
self._pdf = pdf
assert table is not None and isinstance(table, pa.Table)
self._table = table

# always define _schema so that plan() can check it even when no schema is given
self._schema = schema
if schema is not None:
assert isinstance(schema, (DataType, str))

def plan(self, session: "SparkConnectClient") -> proto.Relation:
sink = pa.BufferOutputStream()
table = pa.Table.from_pandas(self._pdf)
with pa.ipc.new_stream(sink, table.schema) as writer:
for b in table.to_batches():
with pa.ipc.new_stream(sink, self._table.schema) as writer:
for b in self._table.to_batches():
writer.write_batch(b)

plan = proto.Relation()
plan.local_relation.data = sink.getvalue().to_pybytes()
if self._schema is not None:
if isinstance(self._schema, DataType):
plan.local_relation.datatype.CopyFrom(pyspark_types_to_proto_types(self._schema))
Contributor Author:

pyspark_types_to_proto_types does not support StructType now.
I'm going to fix it in a separate PR.

elif isinstance(self._schema, str):
plan.local_relation.datatype_str = self._schema
return plan

def print(self, indent: int = 0) -> str:
Expand Down
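A minimal construction sketch for the updated LocalRelation plan (illustrative values, not taken from this PR's tests):

import pyarrow as pa
from pyspark.sql.connect.plan import LocalRelation

table = pa.Table.from_pydict({"id": [1, 2], "name": ["a", "b"]})
rel = LocalRelation(table, schema="id long, name string")  # schema may also be a DataType, or omitted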
175 changes: 88 additions & 87 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

32 changes: 31 additions & 1 deletion python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -40,6 +40,7 @@ import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import pyspark.sql.connect.proto.expressions_pb2
import pyspark.sql.connect.proto.types_pb2
import sys
import typing

@@ -1168,16 +1169,45 @@ class LocalRelation(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

DATA_FIELD_NUMBER: builtins.int
DATATYPE_FIELD_NUMBER: builtins.int
DATATYPE_STR_FIELD_NUMBER: builtins.int
data: builtins.bytes
"""Local collection data serialized into Arrow IPC streaming format which contains
the schema of the data.
"""
@property
def datatype(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
datatype_str: builtins.str
"""Server will use Catalyst parser to parse this string to DataType."""
def __init__(
self,
*,
data: builtins.bytes = ...,
datatype: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
datatype_str: builtins.str = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"datatype", b"datatype", "datatype_str", b"datatype_str", "schema", b"schema"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"data",
b"data",
"datatype",
b"datatype",
"datatype_str",
b"datatype_str",
"schema",
b"schema",
],
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["data", b"data"]) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["schema", b"schema"]
) -> typing_extensions.Literal["datatype", "datatype_str"] | None: ...

global___LocalRelation = LocalRelation

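A short sketch of how the regenerated stub's schema oneof behaves (standard protobuf semantics, illustrative only):

import pyspark.sql.connect.proto as proto

rel = proto.Relation()
rel.local_relation.datatype_str = "id long, name string"
rel.local_relation.WhichOneof("schema")  # 'datatype_str'
rel.local_relation.HasField("datatype")  # False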
114 changes: 102 additions & 12 deletions python/pyspark/sql/connect/session.py
@@ -16,17 +16,35 @@
#

from threading import RLock
from typing import Optional, Any, Union, Dict, cast, overload
from collections.abc import Sized

import numpy as np
import pandas as pd
import pyarrow as pa

from pyspark.sql.types import DataType, StructType

import pyspark.sql.types
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import SQL, Range
from pyspark.sql.connect.plan import SQL, Range, LocalRelation
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.utils import to_str
from . import plan
from ._typing import OptionalPrimitiveType

from typing import (
Optional,
Any,
Union,
Dict,
List,
Tuple,
cast,
overload,
Iterable,
TYPE_CHECKING,
)

if TYPE_CHECKING:
from pyspark.sql.connect._typing import OptionalPrimitiveType


# TODO(SPARK-38912): This method can be dropped once support for Python 3.8 is dropped
@@ -240,7 +258,11 @@ def read(self) -> "DataFrameReader":
"""
return DataFrameReader(self)

def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame":
def createDataFrame(
self,
data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]],
schema: Optional[Union[StructType, str, List[str], Tuple[str, ...]]] = None,
) -> "DataFrame":
"""
Creates a :class:`DataFrame` from a :class:`pandas.DataFrame`.

@@ -249,7 +271,15 @@ def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame":

Parameters
----------
data : :class:`pandas.DataFrame`
data : :class:`pandas.DataFrame` or :class:`list`, or :class:`numpy.ndarray`.
schema : :class:`pyspark.sql.types.DataType`, str or list, optional

When ``schema`` is :class:`pyspark.sql.types.DataType` or a datatype string, it must
match the real data, or an exception will be thrown at runtime. If the given schema is
not :class:`pyspark.sql.types.StructType`, it will be wrapped into a
:class:`pyspark.sql.types.StructType` as its only field, and the field name will be
"value". Each record will also be wrapped into a tuple, which can be converted to row
later.

Returns
-------
@@ -264,9 +294,71 @@

"""
assert data is not None
if len(data) == 0:
if isinstance(data, DataFrame):
raise TypeError("data is already a DataFrame")
if isinstance(data, Sized) and len(data) == 0:
raise ValueError("Input data cannot be empty")
return DataFrame.withPlan(plan.LocalRelation(data), self)

_schema: Optional[StructType] = None
_schema_str: Optional[str] = None
_cols: Optional[List[str]] = None

if isinstance(schema, StructType):
_schema = schema

elif isinstance(schema, str):
_schema_str = schema

elif isinstance(schema, (list, tuple)):
# Must re-encode any unicode strings to be consistent with StructField names
_cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]

# Create the Pandas DataFrame
if isinstance(data, pd.DataFrame):
pdf = data

elif isinstance(data, np.ndarray):
# `data` of numpy.ndarray type will be converted to a pandas DataFrame,
if data.ndim not in [1, 2]:
raise ValueError("NumPy array input should be of 1 or 2 dimensions.")

pdf = pd.DataFrame(data)

if _cols is None:
if data.ndim == 1 or data.shape[1] == 1:
_cols = ["value"]
else:
_cols = ["_%s" % i for i in range(1, data.shape[1] + 1)]

else:
pdf = pd.DataFrame(list(data))

if _cols is None:
_cols = ["_%s" % i for i in range(1, pdf.shape[1] + 1)]

# Validate number of columns
num_cols = pdf.shape[1]
if _schema is not None and len(_schema.fields) != num_cols:
raise ValueError(
f"Length mismatch: Expected axis has {num_cols} elements, "
f"new values have {len(_schema.fields)} elements"
)
elif _cols is not None and len(_cols) != num_cols:
raise ValueError(
f"Length mismatch: Expected axis has {num_cols} elements, "
f"new values have {len(_cols)} elements"
)

table = pa.Table.from_pandas(pdf)

if _schema is not None:
return DataFrame.withPlan(LocalRelation(table, schema=_schema), self)
elif _schema_str is not None:
return DataFrame.withPlan(LocalRelation(table, schema=_schema_str), self)
Contributor Author:
If we can have an RPC for parseTableSchema in AnalyzePlan and implement DataFrame.to, then we do not need to add schema to LocalRelation's proto, and can simplify this to DataFrame.withPlan(LocalRelation(table), self).toDF(...).to(...)

elif _cols is not None and len(_cols) > 0:
return DataFrame.withPlan(LocalRelation(table), self).toDF(*_cols)
else:
return DataFrame.withPlan(LocalRelation(table), self)
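For illustration, the new schema handling supports calls along these lines (assuming spark is a Spark Connect session; examples are not taken from this PR's tests):

import numpy as np
import pandas as pd
from pyspark.sql.types import StructType, StructField, LongType

spark.createDataFrame([(1, "a"), (2, "b")], schema="id long, name string")  # datatype_str path
spark.createDataFrame(np.arange(6).reshape(3, 2), schema=["c1", "c2"])  # column names only, applied via toDF
spark.createDataFrame(pd.DataFrame({"x": [1, 2]}), schema=StructType([StructField("x", LongType())]))  # DataType path; note pyspark_types_to_proto_types does not handle StructType yet (per the comment above)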

@property
def client(self) -> "SparkConnectClient":
@@ -279,9 +371,7 @@ def client(self) -> "SparkConnectClient":
"""
return self._client

def register_udf(
self, function: Any, return_type: Union[str, pyspark.sql.types.DataType]
) -> str:
def register_udf(self, function: Any, return_type: Union[str, DataType]) -> str:
return self._client.register_udf(function, return_type)

def sql(self, sql_string: str) -> "DataFrame":