@@ -60,6 +60,7 @@ message Relation {
     Unpivot unpivot = 25;
     ToSchema to_schema = 26;
     RepartitionByExpression repartition_by_expression = 27;
+    SemanticHash semantic_hash = 28;

     // NA functions
     NAFill fill_na = 90;
@@ -748,3 +749,8 @@ message ToSchema {
   // (Optional) number of partitions, must be positive.
   optional int32 num_partitions = 3;
 }
+
+message SemanticHash {
+  // (Required) The input relation.
+  Relation input = 1;
+}
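
For orientation, a minimal Python sketch of how a client could populate the new message through the generated bindings; the `read.named_table.unparsed_identifier` path is assumed from the existing messages in this proto file, and the table name is a placeholder:

    from pyspark.sql.connect.proto import relations_pb2 as proto

    # Set the semantic_hash arm of the rel_type oneof, wrapping a table scan.
    rel = proto.Relation()
    rel.semantic_hash.input.read.named_table.unparsed_identifier = "my_table"  # placeholder
    print(rel.WhichOneof("rel_type"))  # "semantic_hash"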
@@ -110,6 +110,7 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot)
       case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
         transformRepartitionByExpression(rel.getRepartitionByExpression)
+      case proto.Relation.RelTypeCase.SEMANTIC_HASH => transformSemanticHash(rel.getSemanticHash)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")

@@ -535,6 +536,15 @@ class SparkConnectPlanner(session: SparkSession) {
       numPartitionsOpt)
   }

+  private def transformSemanticHash(rel: proto.SemanticHash): LogicalPlan = {
+    val semanticHash = Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .semanticHash()
+    LocalRelation.fromProduct(
+      output = AttributeReference("semantic_hash", IntegerType, false)() :: Nil,
+      data = Tuple1.apply(semanticHash) :: Nil)
+  }
+
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
     if (!rel.hasInput) {
       throw InvalidPlanInput("Deduplicate needs a plan input")
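
Note the shape of the reply: the hash is computed eagerly on the server and wrapped in a one-row LocalRelation with a single non-nullable IntegerType column named semantic_hash, which the client then collects. For comparison, a sketch of classic (non-Connect) PySpark, where the same Dataset method is reached through the JVM handle:

    df = spark.range(10)

    # Classic PySpark delegates straight to the JVM Dataset's semanticHash().
    h = df._jdf.semanticHash()  # a Python int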
python/pyspark/sql/connect/dataframe.py (9 additions, 2 deletions)
@@ -1363,8 +1363,15 @@ def toJSON(self, *args: Any, **kwargs: Any) -> None:
     def _repr_html_(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("_repr_html_() is not implemented.")

-    def semanticHash(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("semanticHash() is not implemented.")
+    def semanticHash(self) -> int:
+        pdf = DataFrame.withPlan(
+            plan.SemanticHash(child=self._plan),
+            session=self._session,
+        ).toPandas()
+        assert pdf is not None
+        return pdf["semantic_hash"][0]
+
+    semanticHash.__doc__ = PySparkDataFrame.semanticHash.__doc__

     def sameSemantics(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("sameSemantics() is not implemented.")
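
A minimal usage sketch against a Spark Connect session; the table name is a placeholder:

    # Two separately built but semantically equal plans hash to the same value.
    df1 = spark.read.table("my_table").limit(10)  # "my_table" is hypothetical
    df2 = spark.read.table("my_table").limit(10)
    assert df1.semanticHash() == df2.semanticHash()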
python/pyspark/sql/connect/plan.py (11 additions, 0 deletions)
@@ -891,6 +891,17 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
         return plan


+class SemanticHash(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"]) -> None:
+        super().__init__(child)
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = proto.Relation()
+        plan.semantic_hash.input.CopyFrom(self._child.plan(session))
+        return plan
+
+
 class Unpivot(LogicalPlan):
     """Logical plan object for a unpivot operation."""

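
A sketch of how the new node composes with a child plan and serializes; Read is assumed to be the existing table-scan node in this module, and the table name is a placeholder:

    child = Read("my_table")                       # hypothetical table
    rel = SemanticHash(child=child).plan(session)  # session: a SparkConnectClient

    assert rel.HasField("semantic_hash")
    assert rel.semantic_hash.input.HasField("read")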
python/pyspark/sql/connect/proto/relations_pb2.py (114 additions, 100 deletions)
  (Generated file; large diff not rendered.)
python/pyspark/sql/connect/proto/relations_pb2.pyi (28 additions, 0 deletions)
@@ -89,6 +89,7 @@ class Relation(google.protobuf.message.Message):
     UNPIVOT_FIELD_NUMBER: builtins.int
     TO_SCHEMA_FIELD_NUMBER: builtins.int
     REPARTITION_BY_EXPRESSION_FIELD_NUMBER: builtins.int
+    SEMANTIC_HASH_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
     REPLACE_FIELD_NUMBER: builtins.int
@@ -158,6 +159,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def repartition_by_expression(self) -> global___RepartitionByExpression: ...
     @property
+    def semantic_hash(self) -> global___SemanticHash: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -221,6 +224,7 @@ class Relation(google.protobuf.message.Message):
         unpivot: global___Unpivot | None = ...,
         to_schema: global___ToSchema | None = ...,
         repartition_by_expression: global___RepartitionByExpression | None = ...,
+        semantic_hash: global___SemanticHash | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -301,6 +305,8 @@ class Relation(google.protobuf.message.Message):
            b"sample",
            "sample_by",
            b"sample_by",
+           "semantic_hash",
+           b"semantic_hash",
            "set_op",
            b"set_op",
            "show_string",
@@ -390,6 +396,8 @@ class Relation(google.protobuf.message.Message):
            b"sample",
            "sample_by",
            b"sample_by",
+           "semantic_hash",
+           b"semantic_hash",
            "set_op",
            b"set_op",
            "show_string",
@@ -443,6 +451,7 @@ class Relation(google.protobuf.message.Message):
         "unpivot",
         "to_schema",
         "repartition_by_expression",
+        "semantic_hash",
         "fill_na",
         "drop_na",
         "replace",
@@ -2552,3 +2561,22 @@ class RepartitionByExpression(google.protobuf.message.Message):
     ) -> typing_extensions.Literal["num_partitions"] | None: ...

 global___RepartitionByExpression = RepartitionByExpression
+
+class SemanticHash(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(self, field_name: typing_extensions.Literal["input", b"input"]) -> None: ...
+
+global___SemanticHash = SemanticHash
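
A small sketch of the generated message API these stubs describe:

    from pyspark.sql.connect.proto import relations_pb2 as proto

    msg = proto.SemanticHash()
    assert not msg.HasField("input")   # unset until an input relation is assigned
    msg.input.CopyFrom(proto.Relation())
    msg.ClearField("input")            # back to unset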
python/pyspark/sql/dataframe.py (3 additions, 0 deletions)
@@ -4906,6 +4906,9 @@ def semanticHash(self) -> int:

         .. versionadded:: 3.1.0

+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Notes
         -----
         Unlike the standard hash code, the hash is calculated against the query plan
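
An illustration of the documented behavior: column aliases are normalized away when the plan is canonicalized, so these two plans hash identically.

    >>> spark.range(10).selectExpr("id as col0").semanticHash() == \
    ...     spark.range(10).selectExpr("id as col1").semanticHash()  # doctest: +SKIP
    True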
python/pyspark/sql/tests/connect/test_connect_basic.py (8 additions, 1 deletion)
@@ -1987,12 +1987,19 @@ def test_unsupported_functions(self):
             "checkpoint",
             "localCheckpoint",
             "_repr_html_",
-            "semanticHash",
             "sameSemantics",
         ):
             with self.assertRaises(NotImplementedError):
                 getattr(df, f)()

+    def test_semantic_hash(self):
+        df1 = self.connect.read.table(self.tbl_name)
+        df2 = self.connect.read.table(self.tbl_name)
+        semantic_hash1 = df1.limit(10).semanticHash()
+        semantic_hash2 = df2.limit(10).semanticHash()
+
+        self.assertEqual(semantic_hash1, semantic_hash2)
+

 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class ChannelBuilderTests(ReusedPySparkTestCase):