Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -2507,6 +2507,29 @@
],
"sqlState" : "22P03"
},
"INVALID_CLONE_SESSION_REQUEST" : {
"message" : [
"Invalid session clone request."
],
"subClass" : {
"TARGET_SESSION_ID_ALREADY_CLOSED" : {
"message" : [
"Cannot clone session to target session ID <targetSessionId> because a session with this ID was previously closed."
]
},
"TARGET_SESSION_ID_ALREADY_EXISTS" : {
"message" : [
"Cannot clone session to target session ID <targetSessionId> because a session with this ID already exists."
]
},
"TARGET_SESSION_ID_FORMAT" : {
"message" : [
"Target session ID <targetSessionId> for clone operation must be an UUID string of the format '00112233-4455-6677-8899-aabbccddeeff'."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there a different error for this as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but this makes it clear that the format of the target session is invalid, so the exception sent to the server isn't confusing

]
}
},
"sqlState" : "42K04"
},
"INVALID_COLUMN_NAME_AS_PATH" : {
"message" : [
"The datasource <datasource> cannot save the column <columnName> because its name contains some characters that are not allowed in file paths. Please, use an alias to rename it."
Expand Down
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_connect_basic",
"pyspark.sql.tests.connect.test_connect_dataframe_property",
"pyspark.sql.tests.connect.test_connect_channel",
"pyspark.sql.tests.connect.test_connect_clone_session",
"pyspark.sql.tests.connect.test_connect_error",
"pyspark.sql.tests.connect.test_connect_function",
"pyspark.sql.tests.connect.test_connect_collection",
Expand Down
66 changes: 66 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2105,3 +2105,69 @@ def _query_model_size(self, model_ref_id: str) -> int:

ml_command_result = properties["ml_command_result"]
return ml_command_result.param.long

def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient":
"""
Clone this client session on the server side. The server-side session is cloned with
all its current state (SQL configurations, temporary views, registered functions,
catalog state) copied over to a new independent session. The returned client with the
cloned session is isolated from this client's session - any subsequent changes to
either session's server-side state will not be reflected in the other.

Parameters
----------
new_session_id : str, optional
Custom session ID to use for the cloned session (must be a valid UUID).
If not provided, a new UUID will be generated.

Returns
-------
SparkConnectClient
A new SparkConnectClient instance with the cloned session.

Notes
-----
This creates a new server-side session with the specified or generated session ID
while preserving the current session's configuration and state.

.. note::
This is a developer API.
"""
from pyspark.sql.connect.proto import base_pb2 as pb2

request = pb2.CloneSessionRequest(
session_id=self._session_id,
client_type="python",
)
if self._user_id is not None:
request.user_context.user_id = self._user_id

if new_session_id is not None:
request.new_session_id = new_session_id

for attempt in self._retrying():
with attempt:
response: pb2.CloneSessionResponse = self._stub.CloneSession(
request, metadata=self._builder.metadata()
)

# Assert that the returned session ID matches the requested ID if one was provided
if new_session_id is not None:
assert response.new_session_id == new_session_id, (
f"Returned session ID '{response.new_session_id}' does not match "
f"requested ID '{new_session_id}'"
)

# Create a new client with the cloned session ID
new_connection = copy.deepcopy(self._builder)
new_connection.set(ChannelBuilder.PARAM_SESSION_ID, response.new_session_id)

# Create new client and explicitly set the session ID
new_client = SparkConnectClient(
connection=new_connection,
user_id=self._user_id,
use_reattachable_execute=self._use_reattachable_execute,
)
# Ensure the session ID is correctly set from the response
new_client._session_id = response.new_session_id
return new_client
10 changes: 7 additions & 3 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

149 changes: 149 additions & 0 deletions python/pyspark/sql/connect/proto/base_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4081,3 +4081,152 @@ class CheckpointCommandResult(google.protobuf.message.Message):
) -> None: ...

global___CheckpointCommandResult = CheckpointCommandResult

class CloneSessionRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

SESSION_ID_FIELD_NUMBER: builtins.int
CLIENT_OBSERVED_SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
USER_CONTEXT_FIELD_NUMBER: builtins.int
CLIENT_TYPE_FIELD_NUMBER: builtins.int
NEW_SESSION_ID_FIELD_NUMBER: builtins.int
session_id: builtins.str
"""(Required)

The session_id specifies a spark session for a user id (which is specified
by user_context.user_id). The session_id is set by the client to be able to
collate streaming responses from different queries within the dedicated session.
The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
"""
client_observed_server_side_session_id: builtins.str
"""(Optional)

Server-side generated idempotency key from the previous responses (if any). Server
can use this to validate that the server side session has not changed.
"""
@property
def user_context(self) -> global___UserContext:
"""(Required) User context

user_context.user_id and session_id both identify a unique remote spark session on the
server side.
"""
client_type: builtins.str
"""Provides optional information about the client sending the request. This field
can be used for language or version specific information and is only intended for
logging purposes and will not be interpreted by the server.
"""
new_session_id: builtins.str
"""(Optional)
The session_id for the new cloned session. If not provided, a new UUID will be generated.
The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
"""
def __init__(
self,
*,
session_id: builtins.str = ...,
client_observed_server_side_session_id: builtins.str | None = ...,
user_context: global___UserContext | None = ...,
client_type: builtins.str | None = ...,
new_session_id: builtins.str | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_client_observed_server_side_session_id",
b"_client_observed_server_side_session_id",
"_client_type",
b"_client_type",
"_new_session_id",
b"_new_session_id",
"client_observed_server_side_session_id",
b"client_observed_server_side_session_id",
"client_type",
b"client_type",
"new_session_id",
b"new_session_id",
"user_context",
b"user_context",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_client_observed_server_side_session_id",
b"_client_observed_server_side_session_id",
"_client_type",
b"_client_type",
"_new_session_id",
b"_new_session_id",
"client_observed_server_side_session_id",
b"client_observed_server_side_session_id",
"client_type",
b"client_type",
"new_session_id",
b"new_session_id",
"session_id",
b"session_id",
"user_context",
b"user_context",
],
) -> None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal[
"_client_observed_server_side_session_id", b"_client_observed_server_side_session_id"
],
) -> typing_extensions.Literal["client_observed_server_side_session_id"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"]
) -> typing_extensions.Literal["client_type"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_new_session_id", b"_new_session_id"]
) -> typing_extensions.Literal["new_session_id"] | None: ...

global___CloneSessionRequest = CloneSessionRequest

class CloneSessionResponse(google.protobuf.message.Message):
"""Next ID: 5"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

SESSION_ID_FIELD_NUMBER: builtins.int
SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
NEW_SESSION_ID_FIELD_NUMBER: builtins.int
NEW_SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
session_id: builtins.str
"""Session id of the original session that was cloned."""
server_side_session_id: builtins.str
"""Server-side generated idempotency key that the client can use to assert that the server side
session (parent of the cloned session) has not changed.
"""
new_session_id: builtins.str
"""Session id of the new cloned session."""
new_server_side_session_id: builtins.str
"""Server-side session ID of the new cloned session."""
def __init__(
self,
*,
session_id: builtins.str = ...,
server_side_session_id: builtins.str = ...,
new_session_id: builtins.str = ...,
new_server_side_session_id: builtins.str = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"new_server_side_session_id",
b"new_server_side_session_id",
"new_session_id",
b"new_session_id",
"server_side_session_id",
b"server_side_session_id",
"session_id",
b"session_id",
],
) -> None: ...

global___CloneSessionResponse = CloneSessionResponse
55 changes: 55 additions & 0 deletions python/pyspark/sql/connect/proto/base_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def __init__(self, channel):
response_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.FromString,
_registered_method=True,
)
self.CloneSession = channel.unary_unary(
"/spark.connect.SparkConnectService/CloneSession",
request_serializer=spark_dot_connect_dot_base__pb2.CloneSessionRequest.SerializeToString,
response_deserializer=spark_dot_connect_dot_base__pb2.CloneSessionResponse.FromString,
_registered_method=True,
)


class SparkConnectServiceServicer(object):
Expand Down Expand Up @@ -172,6 +178,20 @@ def FetchErrorDetails(self, request, context):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def CloneSession(self, request, context):
"""Create a clone of a Spark Connect session on the server side. The server-side session
is cloned with all its current state (SQL configurations, temporary views, registered
functions, catalog state) copied over to a new independent session. The cloned session
is isolated from the source session - any subsequent changes to either session's
server-side state will not be reflected in the other.

The request can optionally specify a custom session ID for the cloned session (must be
a valid UUID). If not provided, a new UUID will be generated automatically.
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")


def add_SparkConnectServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -225,6 +245,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server):
request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString,
response_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.SerializeToString,
),
"CloneSession": grpc.unary_unary_rpc_method_handler(
servicer.CloneSession,
request_deserializer=spark_dot_connect_dot_base__pb2.CloneSessionRequest.FromString,
response_serializer=spark_dot_connect_dot_base__pb2.CloneSessionResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
"spark.connect.SparkConnectService", rpc_method_handlers
Expand Down Expand Up @@ -536,3 +561,33 @@ def FetchErrorDetails(
metadata,
_registered_method=True,
)

@staticmethod
def CloneSession(
request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.unary_unary(
request,
target,
"/spark.connect.SparkConnectService/CloneSession",
spark_dot_connect_dot_base__pb2.CloneSessionRequest.SerializeToString,
spark_dot_connect_dot_base__pb2.CloneSessionResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True,
)
34 changes: 34 additions & 0 deletions python/pyspark/sql/connect/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,40 @@ def _parse_ddl(self, ddl: str) -> DataType:
assert dt is not None
return dt

def cloneSession(self, new_session_id: Optional[str] = None) -> "SparkSession":
"""
Create a clone of this Spark Connect session on the server side. The server-side session
is cloned with all its current state (SQL configurations, temporary views, registered
functions, catalog state) copied over to a new independent session. The returned cloned
session is isolated from this session - any subsequent changes to either session's
server-side state will not be reflected in the other.

Parameters
----------
new_session_id : str, optional
Custom session ID to use for the cloned session (must be a valid UUID).
If not provided, a new UUID will be generated.

Returns
-------
SparkSession
A new SparkSession instance with the cloned session.

Notes
-----
This creates a new server-side session with the specified or generated session ID
while preserving the current session's configuration and state.

.. note::
This is a developer API.
"""
cloned_client = self._client.clone(new_session_id)
# Create a new SparkSession with the cloned client directly
new_session = object.__new__(SparkSession)
new_session._client = cloned_client
new_session._session_id = cloned_client._session_id
return new_session


SparkSession.__doc__ = PySparkSession.__doc__

Expand Down
Loading