@@ -2741,8 +2741,23 @@ class Dataset[T] private[sql] (
throw new UnsupportedOperationException("localCheckpoint is not implemented.")
}

/**
* Returns `true` when the logical query plans inside both [[Dataset]]s are equal and therefore
* return the same results.
*
* @note
* The equality comparison here is simplified by tolerating cosmetic differences such as
* attribute names.
* @note
* This API can still return `false` for two [[Dataset]]s that return the same results, for
* instance, when they are built from different plans. Such false-negative semantics can be
* useful, for example, when caching. Also note that this comparison may not be fast, because
* it executes an RPC call.
* @since 3.4.0
*/
@DeveloperApi
def sameSemantics(other: Dataset[T]): Boolean = {
Member:

Hm, shall we add it to the Python API too?

I remember this wasn't added because of some concerns from @hvanhovell (maybe I am remembering this wrongly?). This is an important API for ML to use in any event. cc @WeichenXu123 FYI

Contributor Author:

In this PR I have added the Python version. Can you take a look?

Contributor:

It seems @hvanhovell and @cloud-fan had some concerns about sameSemantics and semanticHash:

#38742 (comment)

#38742 (comment)

Contributor:

I think the hash argument still stands. However, I think this is also a matter of setting the right expectations here and updating the docs accordingly.

@WeichenXu123 it would be good to understand your use case.
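
To make the expectation concrete, a minimal sketch of the contract (mirroring the examples in the PySpark sameSemantics docstring; assumes a Spark Connect session `spark`): the comparison is over canonicalized logical plans, so it tolerates cosmetic differences such as attribute names but can return false negatives for plans that merely produce the same rows.

df1 = spark.range(10)
df2 = spark.range(10)

# Plans differing only in the output column name compare equal.
df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col0", df2.id * 2))  # True

# Genuinely different computations do not.
df1.withColumn("col1", df1.id * 2).sameSemantics(df2.withColumn("col1", df2.id + 2))  # False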

Contributor:

Shall we add @DeveloperApi for this one?

throw new UnsupportedOperationException("sameSemantics is not implemented.")
sparkSession.sameSemantics(this.plan, other.plan)
}

def semanticHash(): Int = {
@@ -400,6 +400,10 @@ class SparkSession private[sql] (
client.analyze(method, Some(plan), explainMode)
}

private[sql] def sameSemantics(plan: proto.Plan, otherPlan: proto.Plan): Boolean = {
client.sameSemantics(plan, otherPlan).getSameSemantics.getResult
}

private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
val value = client.execute(plan)
val result = new SparkResult(value, allocator, encoder)
@@ -141,6 +141,20 @@ private[sql] class SparkConnectClient(
builder.setSparkVersion(proto.AnalyzePlanRequest.SparkVersion.newBuilder().build())
case other => throw new IllegalArgumentException(s"Unknown Analyze request $other")
}
analyze(builder)
}

def sameSemantics(plan: proto.Plan, otherPlan: proto.Plan): proto.AnalyzePlanResponse = {
val builder = proto.AnalyzePlanRequest.newBuilder()
builder.setSameSemantics(
proto.AnalyzePlanRequest.SameSemantics
.newBuilder()
.setTargetPlan(plan)
.setOtherPlan(otherPlan))
analyze(builder)
}

private def analyze(builder: proto.AnalyzePlanRequest.Builder): proto.AnalyzePlanResponse = {
val request = builder
.setUserContext(userContext)
.setClientId(sessionId)
@@ -628,6 +628,12 @@ class ClientE2ETestSuite extends RemoteSparkSession {
val result = spark.createDataFrame(data.asJava, schema).collect()
assert(result === data)
}

test("SameSemantics") {
val plan = spark.sql("select 1")
val otherPlan = spark.sql("select 1")
assert(plan.sameSemantics(otherPlan))
}
}

private[sql] case class MyType(id: Long, a: Double, b: Double)
@@ -79,6 +79,7 @@ message AnalyzePlanRequest {
InputFiles input_files = 9;
SparkVersion spark_version = 10;
DDLParse ddl_parse = 11;
SameSemantics same_semantics = 12;
}

message Schema {
@@ -145,6 +146,16 @@ message AnalyzePlanRequest {
// (Required) The DDL formatted string to be parsed.
string ddl_string = 1;
}


// Returns `true` when the logical query plans are equal and therefore return the same results.
message SameSemantics {
// (Required) The plan to be compared.
Plan target_plan = 1;

// (Required) The other plan to be compared.
Plan other_plan = 2;
}
}
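
For illustration, a sketch of how a client might populate this request via the generated Python bindings (assuming the generated module is importable as shown, and that `target` and `other` are already-built `Plan` messages):

from pyspark.sql.connect.proto import base_pb2 as pb2

req = pb2.AnalyzePlanRequest()
req.same_semantics.target_plan.CopyFrom(target)  # (Required) the plan to be compared
req.same_semantics.other_plan.CopyFrom(other)    # (Required) the other plan
# The server replies with AnalyzePlanResponse.same_semantics.result, a bool.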

// Response to performing analysis of the query. Contains relevant metadata to be able to
@@ -161,6 +172,7 @@ message AnalyzePlanResponse {
InputFiles input_files = 7;
SparkVersion spark_version = 8;
DDLParse ddl_parse = 9;
SameSemantics same_semantics = 10;
}

message Schema {
@@ -195,6 +207,10 @@ message AnalyzePlanResponse {
message DDLParse {
DataType parsed = 1;
}

message SameSemantics {
bool result = 1;
}
}

// A request to be executed by the service.
@@ -140,6 +140,18 @@ private[connect] class SparkConnectAnalyzeHandler(
.setParsed(DataTypeProtoConverter.toConnectProtoType(schema))
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
val target = Dataset.ofRows(
session,
planner.transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
val other = Dataset.ofRows(
session,
planner.transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
builder.setSameSemantics(
proto.AnalyzePlanResponse.SameSemantics
.newBuilder()
.setResult(target.sameSemantics(other)))

case other => throw InvalidPlanInput(s"Unknown Analyze Method $other!")
}

17 changes: 17 additions & 0 deletions python/pyspark/sql/connect/client.py
@@ -401,6 +401,7 @@ def __init__(
input_files: Optional[List[str]],
spark_version: Optional[str],
parsed: Optional[pb2.DataType],
is_same_semantics: Optional[bool],
):
self.schema = schema
self.explain_string = explain_string
@@ -410,6 +411,7 @@ def __init__(
self.input_files = input_files
self.spark_version = spark_version
self.parsed = parsed
self.is_same_semantics = is_same_semantics

@classmethod
def fromProto(cls, pb: Any) -> "AnalyzeResult":
@@ -421,6 +423,7 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
input_files: Optional[List[str]] = None
spark_version: Optional[str] = None
parsed: Optional[pb2.DataType] = None
is_same_semantics: Optional[bool] = None

if pb.HasField("schema"):
schema = pb.schema.schema
@@ -438,6 +441,8 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
spark_version = pb.spark_version.version
elif pb.HasField("ddl_parse"):
parsed = pb.ddl_parse.parsed
elif pb.HasField("same_semantics"):
is_same_semantics = pb.same_semantics.result
else:
raise SparkConnectException("No analyze result found!")

@@ -450,6 +455,7 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
input_files,
spark_version,
parsed,
is_same_semantics,
)


@@ -690,6 +696,14 @@ def execute_command(
else:
return (None, properties)

def same_semantics(self, plan: pb2.Plan, other: pb2.Plan) -> bool:
"""
Return whether the two given plans have the same semantics.
"""
result = self._analyze(method="same_semantics", plan=plan, other=other).is_same_semantics
assert result is not None
return result
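
A sketch of driving this helper directly (hypothetical names; `client` is a SparkConnectClient and `p1`/`p2` are `pb2.Plan` messages, e.g. obtained from a DataFrame plan via `to_proto`):

result = client.same_semantics(plan=p1, other=p2)  # one AnalyzePlanRequest RPC
# `result` is a plain bool; the assert above guards against a response in
# which the server did not populate the same_semantics field.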

def close(self) -> None:
"""
Close the channel.
@@ -765,6 +779,9 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
req.spark_version.SetInParent()
elif method == "ddl_parse":
req.ddl_parse.ddl_string = cast(str, kwargs.get("ddl_string"))
elif method == "same_semantics":
req.same_semantics.target_plan.CopyFrom(cast(pb2.Plan, kwargs.get("plan")))
req.same_semantics.other_plan.CopyFrom(cast(pb2.Plan, kwargs.get("other")))
else:
raise ValueError(f"Unknown Analyze method: {method}")

11 changes: 9 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
@@ -1610,8 +1610,15 @@ def _repr_html_(self, *args: Any, **kwargs: Any) -> None:
def semanticHash(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("semanticHash() is not implemented.")

def sameSemantics(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("sameSemantics() is not implemented.")
def sameSemantics(self, other: "DataFrame") -> bool:
assert self._plan is not None
assert other._plan is not None
return self._session.client.same_semantics(
plan=self._plan.to_proto(self._session.client),
other=other._plan.to_proto(other._session.client),
)

sameSemantics.__doc__ = PySparkDataFrame.sameSemantics.__doc__
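
A brief usage sketch of the new binding (assumes a Spark Connect session `spark`; note that each call issues an AnalyzePlanRequest RPC):

df1 = spark.sql("SELECT 1")
df2 = spark.sql("SELECT 1")
assert df1.sameSemantics(df2)  # identical plans compare equal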

def writeTo(self, table: str) -> "DataFrameWriterV2":
assert self._plan is not None
214 changes: 120 additions & 94 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

64 changes: 64 additions & 0 deletions python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -318,6 +318,38 @@ class AnalyzePlanRequest(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["ddl_string", b"ddl_string"]
) -> None: ...

class SameSemantics(google.protobuf.message.Message):
"""Returns `true` when the logical query plans are equal and therefore return same results."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

TARGET_PLAN_FIELD_NUMBER: builtins.int
OTHER_PLAN_FIELD_NUMBER: builtins.int
@property
def target_plan(self) -> global___Plan:
"""(Required) The plan to be compared."""
@property
def other_plan(self) -> global___Plan:
"""(Required) The other plan to be compared."""
def __init__(
self,
*,
target_plan: global___Plan | None = ...,
other_plan: global___Plan | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"other_plan", b"other_plan", "target_plan", b"target_plan"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"other_plan", b"other_plan", "target_plan", b"target_plan"
],
) -> None: ...

CLIENT_ID_FIELD_NUMBER: builtins.int
USER_CONTEXT_FIELD_NUMBER: builtins.int
CLIENT_TYPE_FIELD_NUMBER: builtins.int
@@ -329,6 +361,7 @@
INPUT_FILES_FIELD_NUMBER: builtins.int
SPARK_VERSION_FIELD_NUMBER: builtins.int
DDL_PARSE_FIELD_NUMBER: builtins.int
SAME_SEMANTICS_FIELD_NUMBER: builtins.int
client_id: builtins.str
"""(Required)

@@ -359,6 +392,8 @@
def spark_version(self) -> global___AnalyzePlanRequest.SparkVersion: ...
@property
def ddl_parse(self) -> global___AnalyzePlanRequest.DDLParse: ...
@property
def same_semantics(self) -> global___AnalyzePlanRequest.SameSemantics: ...
def __init__(
self,
*,
@@ -373,6 +408,7 @@
input_files: global___AnalyzePlanRequest.InputFiles | None = ...,
spark_version: global___AnalyzePlanRequest.SparkVersion | None = ...,
ddl_parse: global___AnalyzePlanRequest.DDLParse | None = ...,
same_semantics: global___AnalyzePlanRequest.SameSemantics | None = ...,
) -> None: ...
def HasField(
self,
@@ -393,6 +429,8 @@
b"is_local",
"is_streaming",
b"is_streaming",
"same_semantics",
b"same_semantics",
"schema",
b"schema",
"spark_version",
@@ -424,6 +462,8 @@
b"is_local",
"is_streaming",
b"is_streaming",
"same_semantics",
b"same_semantics",
"schema",
b"schema",
"spark_version",
@@ -450,6 +490,7 @@
"input_files",
"spark_version",
"ddl_parse",
"same_semantics",
] | None: ...

global___AnalyzePlanRequest = AnalyzePlanRequest
@@ -583,6 +624,20 @@ class AnalyzePlanResponse(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["parsed", b"parsed"]
) -> None: ...

class SameSemantics(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

RESULT_FIELD_NUMBER: builtins.int
result: builtins.bool
def __init__(
self,
*,
result: builtins.bool = ...,
) -> None: ...
def ClearField(
self, field_name: typing_extensions.Literal["result", b"result"]
) -> None: ...

CLIENT_ID_FIELD_NUMBER: builtins.int
SCHEMA_FIELD_NUMBER: builtins.int
EXPLAIN_FIELD_NUMBER: builtins.int
Expand All @@ -592,6 +647,7 @@ class AnalyzePlanResponse(google.protobuf.message.Message):
INPUT_FILES_FIELD_NUMBER: builtins.int
SPARK_VERSION_FIELD_NUMBER: builtins.int
DDL_PARSE_FIELD_NUMBER: builtins.int
SAME_SEMANTICS_FIELD_NUMBER: builtins.int
client_id: builtins.str
@property
def schema(self) -> global___AnalyzePlanResponse.Schema: ...
Expand All @@ -609,6 +665,8 @@ class AnalyzePlanResponse(google.protobuf.message.Message):
def spark_version(self) -> global___AnalyzePlanResponse.SparkVersion: ...
@property
def ddl_parse(self) -> global___AnalyzePlanResponse.DDLParse: ...
@property
def same_semantics(self) -> global___AnalyzePlanResponse.SameSemantics: ...
def __init__(
self,
*,
@@ -621,6 +679,7 @@
input_files: global___AnalyzePlanResponse.InputFiles | None = ...,
spark_version: global___AnalyzePlanResponse.SparkVersion | None = ...,
ddl_parse: global___AnalyzePlanResponse.DDLParse | None = ...,
same_semantics: global___AnalyzePlanResponse.SameSemantics | None = ...,
) -> None: ...
def HasField(
self,
@@ -637,6 +696,8 @@
b"is_streaming",
"result",
b"result",
"same_semantics",
b"same_semantics",
"schema",
b"schema",
"spark_version",
@@ -662,6 +723,8 @@
b"is_streaming",
"result",
b"result",
"same_semantics",
b"same_semantics",
"schema",
b"schema",
"spark_version",
@@ -681,6 +744,7 @@
"input_files",
"spark_version",
"ddl_parse",
"same_semantics",
] | None: ...

global___AnalyzePlanResponse = AnalyzePlanResponse
6 changes: 5 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2809,6 +2809,11 @@ def test_version(self):
self.spark.version,
)

def test_same_semantics(self):
plan = self.connect.sql("SELECT 1")
other = self.connect.sql("SELECT 1")
self.assertTrue(plan.sameSemantics(other))
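
A possible follow-up test (a sketch, not part of this diff; name and placement hypothetical) documenting the false-negative semantics discussed in the review: plans that merely return the same rows need not compare equal.

def test_same_semantics_false_negative(self):
    plan = self.connect.sql("SELECT 1")
    other = self.connect.sql("SELECT 2 - 1")
    result = plan.sameSemantics(other)
    # May legitimately be False even though both queries return the same
    # rows: equality is over canonicalized plans, not over results.
    self.assertIsInstance(result, bool)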

def test_unsupported_functions(self):
# SPARK-41225: Disable unsupported functions.
df = self.connect.read.table(self.tbl_name)
@@ -2825,7 +2830,6 @@ def test_unsupported_functions(self):
"localCheckpoint",
"_repr_html_",
"semanticHash",
"sameSemantics",
):
with self.assertRaises(NotImplementedError):
getattr(df, f)()