@@ -60,6 +60,7 @@ message Relation {
Unpivot unpivot = 25;
ToSchema to_schema = 26;
RepartitionByExpression repartition_by_expression = 27;
SameSemantics same_semantics = 29;

// NA functions
NAFill fill_na = 90;
@@ -748,3 +749,11 @@ message ToSchema {
// (Optional) number of partitions, must be positive.
optional int32 num_partitions = 3;
}

message SameSemantics {
// (Required) The input relation.
Relation input = 1;

// (Required) The other Relation to compare against.
Relation other = 2;
}
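
The new relation carries nothing but two complete child plans. As a minimal sketch, this is how a client could populate it through the generated Python bindings (the proto alias follows the import convention used in plan.py below; the empty Relation() messages are placeholders for real plans):

from pyspark.sql.connect import proto

rel = proto.Relation()
# Both fields are full Relation messages, so arbitrarily deep plans can nest here.
rel.same_semantics.input.CopyFrom(proto.Relation())  # placeholder: the target plan
rel.same_semantics.other.CopyFrom(proto.Relation())  # placeholder: the plan to compare against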
@@ -110,6 +110,8 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot)
case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
transformRepartitionByExpression(rel.getRepartitionByExpression)
case proto.Relation.RelTypeCase.SAME_SEMANTICS =>
transformSameSemantics(rel.getSameSemantics)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")

@@ -535,6 +537,17 @@ class SparkConnectPlanner(session: SparkSession) {
numPartitionsOpt)
}

private def transformSameSemantics(rel: proto.SameSemantics): LogicalPlan = {
val otherDS = Dataset
.ofRows(session, transformRelation(rel.getOther))
val sameSemantics = Dataset
.ofRows(session, transformRelation(rel.getInput))
.sameSemantics(otherDS)
LocalRelation.fromProduct(
output = AttributeReference("same_semantics", BooleanType, false)() :: Nil,
data = Tuple1.apply(sameSemantics) :: Nil)
}

private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
if (!rel.hasInput) {
throw InvalidPlanInput("Deduplicate needs a plan input")
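Note the pattern in transformSameSemantics above: unlike most handlers here, which build a lazy LogicalPlan over their inputs, it evaluates the comparison eagerly on the server and ships the answer back as a one-row LocalRelation with a single non-nullable boolean column named same_semantics. From the client's perspective the contract looks roughly like this sketch (hypothetical; it uses collect() where the actual client code below uses toPandas()):

# df1 and df2 are Spark Connect DataFrames sharing a session.
rows = DataFrame.withPlan(
    plan.SameSemantics(child=df1._plan, other=df2._plan),
    session=df1._session,
).collect()
# rows == [Row(same_semantics=True)] or [Row(same_semantics=False)]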
11 changes: 9 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
@@ -1366,8 +1366,15 @@ def _repr_html_(self, *args: Any, **kwargs: Any) -> None:
def semanticHash(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("semanticHash() is not implemented.")

def sameSemantics(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("sameSemantics() is not implemented.")
def sameSemantics(self, other: "DataFrame") -> bool:
pdf = DataFrame.withPlan(
plan.SameSemantics(child=self._plan, other=other._plan),
session=self._session,
).toPandas()
assert pdf is not None
return pdf["same_semantics"][0]

sameSemantics.__doc__ = PySparkDataFrame.sameSemantics.__doc__

def writeTo(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("writeTo() is not implemented.")
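With that change, the Connect DataFrame mirrors the classic API. A usage sketch, adapted from the example in the pyspark.sql.DataFrame.sameSemantics docstring:

df1 = spark.range(10)
df2 = spark.range(10)

# Same computation expressed on different objects: semantically equal.
df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col1", df2.id * 2))  # True

# A different computation: not semantically equal.
df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col1", df2.id + 2))  # False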
14 changes: 14 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -891,6 +891,20 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan


class SameSemantics(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], other: Optional["LogicalPlan"]) -> None:
super().__init__(child)
self.other = other

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None and self.other is not None
plan = proto.Relation()
plan.same_semantics.input.CopyFrom(self._child.plan(session))
plan.same_semantics.other.CopyFrom(self.other.plan(session))

return plan


class Unpivot(LogicalPlan):
"""Logical plan object for a unpivot operation."""

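SameSemantics is unusual among these plan nodes in that it serializes two complete subplans into a single relation: child travels through the standard LogicalPlan parent slot, while other is carried separately. A hedged sketch of constructing one by hand (DataFrame.sameSemantics normally does this; client stands for the session's SparkConnectClient):

node = SameSemantics(child=df1._plan, other=df2._plan)
rel = node.plan(client)
assert rel.same_semantics.HasField("input")
assert rel.same_semantics.HasField("other")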
214 changes: 114 additions & 100 deletions python/pyspark/sql/connect/proto/relations_pb2.py

(Large diff not rendered; relations_pb2.py is regenerated from the proto definition.)

35 changes: 35 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -89,6 +89,7 @@ class Relation(google.protobuf.message.Message):
UNPIVOT_FIELD_NUMBER: builtins.int
TO_SCHEMA_FIELD_NUMBER: builtins.int
REPARTITION_BY_EXPRESSION_FIELD_NUMBER: builtins.int
SAME_SEMANTICS_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -158,6 +159,8 @@ class Relation(google.protobuf.message.Message):
@property
def repartition_by_expression(self) -> global___RepartitionByExpression: ...
@property
def same_semantics(self) -> global___SameSemantics: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -221,6 +224,7 @@ class Relation(google.protobuf.message.Message):
unpivot: global___Unpivot | None = ...,
to_schema: global___ToSchema | None = ...,
repartition_by_expression: global___RepartitionByExpression | None = ...,
same_semantics: global___SameSemantics | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -297,6 +301,8 @@ class Relation(google.protobuf.message.Message):
b"repartition_by_expression",
"replace",
b"replace",
"same_semantics",
b"same_semantics",
"sample",
b"sample",
"sample_by",
@@ -386,6 +392,8 @@ class Relation(google.protobuf.message.Message):
b"repartition_by_expression",
"replace",
b"replace",
"same_semantics",
b"same_semantics",
"sample",
b"sample",
"sample_by",
@@ -443,6 +451,7 @@ class Relation(google.protobuf.message.Message):
"unpivot",
"to_schema",
"repartition_by_expression",
"same_semantics",
"fill_na",
"drop_na",
"replace",
@@ -2552,3 +2561,29 @@ class RepartitionByExpression(google.protobuf.message.Message):
) -> typing_extensions.Literal["num_partitions"] | None: ...

global___RepartitionByExpression = RepartitionByExpression

class SameSemantics(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

INPUT_FIELD_NUMBER: builtins.int
OTHER_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) The input relation."""
@property
def other(self) -> global___Relation:
"""(Required) The other Relation to compare against."""
def __init__(
self,
*,
input: global___Relation | None = ...,
other: global___Relation | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["input", b"input", "other", b"other"]
) -> builtins.bool: ...
def ClearField(
self, field_name: typing_extensions.Literal["input", b"input", "other", b"other"]
) -> None: ...

global___SameSemantics = SameSemantics
3 changes: 3 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -4864,6 +4864,9 @@ def sameSemantics(self, other: "DataFrame") -> bool:

.. versionadded:: 3.1.0

.. versionchanged:: 3.4.0
Supports Spark Connect.

Notes
-----
The equality comparison here is simplified by tolerating the cosmetic differences
9 changes: 8 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1988,12 +1988,19 @@ def test_unsupported_functions(self):
"localCheckpoint",
"_repr_html_",
"semanticHash",
"sameSemantics",
"writeTo",
):
with self.assertRaises(NotImplementedError):
getattr(df, f)()

def test_same_semantics(self):
df1 = self.connect.read.table(self.tbl_name).limit(10)
df2 = self.connect.read.table(self.tbl_name).limit(10)
df3 = self.connect.read.table(self.tbl_name).limit(1)

self.assertTrue(df1.sameSemantics(df2))
self.assertFalse(df1.sameSemantics(df3))


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ChannelBuilderTests(ReusedPySparkTestCase):