Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,115 @@ def execute_command_as_iterator(
},
)

def batch_execute(
    self,
    plan_sequences: List[List[Tuple[pb2.Plan, Optional[str]]]],
    sequence_operation_ids: Optional[List[Optional[str]]] = None,
) -> pb2.BatchExecutePlanResponse:
    """
    Execute multiple sequences of plans in batch.

    Each sequence executes sequentially; all sequences execute in parallel.
    Single-plan batches are treated as sequences containing one plan.

    Parameters
    ----------
    plan_sequences : list of list of (Plan, Optional[str])
        List of sequences. Each sequence is a list of (plan, operation_id)
        tuples. A ``None`` (or empty) operation_id leaves the field unset so
        the server can assign one.
    sequence_operation_ids : list of Optional[str], optional
        Optional operation IDs for each sequence. Missing or falsy entries
        leave the sequence ID unset.

    Returns
    -------
    BatchExecutePlanResponse
        Response containing sequence operation IDs and query operation IDs.

    Raises
    ------
    PySparkValueError
        With error class ``INVALID_HANDLE.FORMAT`` if any supplied
        operation ID is not a valid UUID.
    """
    import uuid

    def _validated_handle(handle: str) -> str:
        # Client-supplied operation IDs must be UUIDs; reject malformed
        # handles before issuing the RPC. Shared by both the sequence-level
        # and per-plan ID paths (previously duplicated inline).
        try:
            uuid.UUID(handle)
        except ValueError:
            raise PySparkValueError(
                error_class="INVALID_HANDLE.FORMAT",
                message_parameters={"handle": handle},
            )
        return handle

    req = pb2.BatchExecutePlanRequest()
    req.session_id = self._session_id
    if self._user_id:
        req.user_context.user_id = self._user_id
    req.client_type = self._builder.userAgent

    # Echo the server-side session id (when known) so the server can detect
    # a session mismatch.
    if self._server_side_session_id:
        req.client_observed_server_side_session_id = self._server_side_session_id

    seq_op_ids = sequence_operation_ids or []

    for idx, sequence in enumerate(plan_sequences):
        seq = req.plan_sequences.add()

        # Set the sequence-level operation ID if one was provided.
        if idx < len(seq_op_ids) and seq_op_ids[idx]:
            seq.sequence_operation_id = _validated_handle(seq_op_ids[idx])

        # Add the individual plan executions for this sequence.
        for plan, op_id in sequence:
            plan_exec = seq.plan_executions.add()
            plan_exec.plan.CopyFrom(plan)
            if op_id:
                plan_exec.operation_id = _validated_handle(op_id)

    metadata = self._builder.metadata()
    return self._stub.BatchExecutePlan(req, metadata=metadata)

def reattach_execute(self, operation_id: str) -> Iterator[pb2.ExecutePlanResponse]:
    """
    Reattach to an existing operation by operation ID.

    Parameters
    ----------
    operation_id : str
        The operation ID to reattach to (must be a valid UUID).

    Returns
    -------
    Iterator[pb2.ExecutePlanResponse]
        A streaming iterator of ExecutePlanResponse messages. Responses are
        consumed lazily as the caller iterates; this method does not drain
        the stream itself.

    Raises
    ------
    PySparkValueError
        With error class ``INVALID_HANDLE.FORMAT`` if ``operation_id`` is
        not a valid UUID.
    """
    import uuid

    # Reject malformed handles client-side before issuing the RPC.
    try:
        uuid.UUID(operation_id)
    except ValueError:
        raise PySparkValueError(
            error_class="INVALID_HANDLE.FORMAT",
            message_parameters={"handle": operation_id},
        )

    req = pb2.ReattachExecuteRequest()
    req.session_id = self._session_id
    if self._user_id:
        req.user_context.user_id = self._user_id
    req.operation_id = operation_id
    req.client_type = self._builder.userAgent

    # Echo the server-side session id (when known) so the server can detect
    # a session mismatch.
    if self._server_side_session_id:
        req.client_observed_server_side_session_id = self._server_side_session_id

    metadata = self._builder.metadata()
    return self._stub.ReattachExecute(req, metadata=metadata)

def same_semantics(self, plan: pb2.Plan, other: pb2.Plan) -> bool:
"""
return if two plans have the same semantics.
Expand Down
248 changes: 130 additions & 118 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

Loading