@@ -328,20 +328,16 @@ message Deduplicate {

// A relation that does not need to be qualified by name.
message LocalRelation {
// Local collection data serialized into Arrow IPC streaming format which contains
// (Optional) Local collection data serialized into Arrow IPC streaming format which contains
// the schema of the data.
bytes data = 1;
optional bytes data = 1;

// (Optional) The user provided schema.
// (Optional) The schema of local data.
// It should be either a DDL-formatted type string or a JSON string.
//
// The Sever side will update the column names and data types according to this schema.
oneof schema {

DataType datatype = 2;

// Server will use Catalyst parser to parse this string to DataType.
string datatype_str = 3;
}
// The server side will update the column names and data types according to this schema.
// If the 'data' is not provided, then this schema will be required.
optional string schema = 2;
}

// Relation of type [[Sample]] that samples a fraction of the dataset.
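
With both fields now optional, a `LocalRelation` message can carry Arrow data, just a schema string, or both; the schema-only form is what makes empty DataFrames expressible over Connect. A minimal sketch of the two shapes, assuming the generated `pyspark.sql.connect.proto` bindings and the `pyarrow` usage seen later in this PR:

```python
import pyarrow as pa
from pyspark.sql.connect import proto  # generated bindings; import path assumed

# Schema-only LocalRelation: no Arrow payload, so the server must build an
# empty relation from the DDL (or JSON) schema string alone.
empty_rel = proto.Relation()
empty_rel.local_relation.schema = "a INT, b STRING"

# Data-only LocalRelation: the schema travels inside the Arrow IPC stream.
table = pa.table({"a": [1, 2], "b": ["x", "y"]})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    for batch in table.to_batches():
        writer.write_batch(batch)
data_rel = proto.Relation()
data_rel.local_relation.data = sink.getvalue().to_pybytes()
```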
@@ -571,47 +571,61 @@ class SparkConnectPlanner(session: SparkSession) {
try {
parser.parseTableSchema(sqlText)
} catch {
case _: ParseException =>
Author comment: refer to #38979 (comment)
case e: ParseException =>
try {
parser.parseDataType(sqlText)
} catch {
case _: ParseException =>
parser.parseDataType(s"struct<${sqlText.trim}>")
try {
parser.parseDataType(s"struct<${sqlText.trim}>")
} catch {
case _: ParseException =>
throw e
}
}
}
}

private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
Iterator(rel.getData.toByteArray),
TaskContext.get())
if (structType == null) {
throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
var schema: StructType = null
if (rel.hasSchema) {
val schemaType = DataType.parseTypeWithFallback(
rel.getSchema,
parseDatatypeString,
fallbackParser = DataType.fromJson)
schema = schemaType match {
case s: StructType => s
case d => StructType(Seq(StructField("value", d)))
}
}
val attributes = structType.toAttributes
val proj = UnsafeProjection.create(attributes, attributes)
val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)

if (!rel.hasDatatype && !rel.hasDatatypeStr) {
return relation
}
if (rel.hasData) {
val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
Iterator(rel.getData.toByteArray),
TaskContext.get())
if (structType == null) {
throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
}
val attributes = structType.toAttributes
val proj = UnsafeProjection.create(attributes, attributes)
val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)

val schemaType = if (rel.hasDatatype) {
DataTypeProtoConverter.toCatalystType(rel.getDatatype)
if (schema == null) {
relation
} else {
Dataset
.ofRows(session, logicalPlan = relation)
.toDF(schema.names: _*)
.to(schema)
.logicalPlan
}
} else {
parseDatatypeString(rel.getDatatypeStr)
}

val schemaStruct = schemaType match {
case s: StructType => s
case d => StructType(Seq(StructField("value", d)))
if (schema == null) {
throw InvalidPlanInput(
s"Schema for LocalRelation is required when the input data is not provided.")
}
LocalRelation(schema.toAttributes, data = Seq.empty)
}

Dataset
.ofRows(session, logicalPlan = relation)
.toDF(schemaStruct.names: _*)
.to(schemaStruct)
.logicalPlan
}

private def transformReadRel(rel: proto.Read): LogicalPlan = {
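
Server-side, `transformLocalRelation` now covers three cases: Arrow data with an optional schema (decode the batch, then rename and cast the columns via `Dataset.to` when a schema is supplied), a schema without data (build an empty `LocalRelation`), and neither, which is rejected with `InvalidPlanInput`. The schema string itself is parsed with `DataType.parseTypeWithFallback`, trying the DDL parsers first and falling back to JSON, and the reworked `parseDatatypeString` now rethrows the original `ParseException` when every DDL attempt fails rather than surfacing the error from the `struct<...>` retry. A rough sketch of what this enables from the Python client, assuming `spark` is an existing Spark Connect session and that `createDataFrame` accepts an empty list plus a DDL schema string as in classic PySpark:

```python
# `spark` is assumed to be an already-created Spark Connect SparkSession.
# Schema-only LocalRelation: the planner produces an empty relation with the
# requested columns and types, so isEmpty() can work on the client.
empty_df = spark.createDataFrame([], schema="a INT, b STRING")
assert empty_df.columns == ["a", "b"]
assert empty_df.isEmpty()
```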
python/pyspark/sql/connect/dataframe.py (3 changes: 0 additions & 3 deletions)
@@ -1426,9 +1426,6 @@ def _test() -> None:
# TODO(SPARK-41827): groupBy requires all cols be Column or str
del pyspark.sql.connect.dataframe.DataFrame.groupBy.__doc__

# TODO(SPARK-41828): Implement creating empty DataFrame
del pyspark.sql.connect.dataframe.DataFrame.isEmpty.__doc__

# TODO(SPARK-41829): Add Dataframe sort ordering
del pyspark.sql.connect.dataframe.DataFrame.sort.__doc__
del pyspark.sql.connect.dataframe.DataFrame.sortWithinPartitions.__doc__
python/pyspark/sql/connect/plan.py (34 changes: 19 additions & 15 deletions)
@@ -270,30 +270,34 @@ class LocalRelation(LogicalPlan):

def __init__(
self,
table: "pa.Table",
schema: Optional[Union[DataType, str]] = None,
table: Optional["pa.Table"],
schema: Optional[str] = None,
) -> None:
super().__init__(None)
assert table is not None and isinstance(table, pa.Table)

if table is None:
assert schema is not None
else:
assert isinstance(table, pa.Table)

assert schema is None or isinstance(schema, str)

self._table = table

if schema is not None:
assert isinstance(schema, (DataType, str))
self._schema = schema

def plan(self, session: "SparkConnectClient") -> proto.Relation:
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, self._table.schema) as writer:
for b in self._table.to_batches():
writer.write_batch(b)

plan = proto.Relation()
plan.local_relation.data = sink.getvalue().to_pybytes()

if self._table is not None:
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, self._table.schema) as writer:
for b in self._table.to_batches():
writer.write_batch(b)
plan.local_relation.data = sink.getvalue().to_pybytes()

if self._schema is not None:
if isinstance(self._schema, DataType):
plan.local_relation.datatype.CopyFrom(pyspark_types_to_proto_types(self._schema))
elif isinstance(self._schema, str):
plan.local_relation.datatype_str = self._schema
plan.local_relation.schema = self._schema
return plan

def print(self, indent: int = 0) -> str:
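
On the client, the `LocalRelation` plan node now tolerates a missing Arrow table as long as a schema string is given, and the schema argument is narrowed to strings only (the `DataType` branch and the old `datatype`/`datatype_str` proto fields are gone). A small sketch of constructing the plan node directly, assuming the class is importable from `pyspark.sql.connect.plan` as shown in this file:

```python
import pyarrow as pa
from pyspark.sql.connect.plan import LocalRelation

# Schema-only plan node: plan() will emit a LocalRelation proto with only
# the `schema` field set, which the server turns into an empty relation.
empty_plan = LocalRelation(table=None, schema="a INT, b STRING")

# Table-backed plan node: plan() serializes the Arrow batches into `data`.
table_plan = LocalRelation(pa.table({"a": [1, 2], "b": ["x", "y"]}))
```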