5 changes: 3 additions & 2 deletions python/pyspark/sql/connect/conversion.py
@@ -577,7 +577,8 @@ def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "Dat
from pyspark.sql.connect.session import SparkSession
import pyspark.sql.connect.plan as plan

session = SparkSession.active()
return DataFrame(
plan=plan.CachedRemoteRelation(relation.relation_id),
session=SparkSession.active(),
plan=plan.CachedRemoteRelation(relation.relation_id, session),
session=session,
)
38 changes: 0 additions & 38 deletions python/pyspark/sql/connect/dataframe.py
@@ -16,7 +16,6 @@
#

# mypy: disable-error-code="override"
from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
from pyspark.errors.exceptions.base import (
SessionNotSameException,
PySparkIndexError,
@@ -138,41 +137,6 @@ def __init__(
# by __repr__ and _repr_html_ while eager evaluation opens.
self._support_repr_html = False
self._cached_schema: Optional[StructType] = None
self._cached_remote_relation_id: Optional[str] = None

def __del__(self) -> None:
# If session is already closed, all cached DataFrame should be released.
if not self._session.client.is_closed and self._cached_remote_relation_id is not None:
try:
command = plan.RemoveRemoteCachedRelation(
plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id)
).command(session=self._session.client)
req = self._session.client._execute_plan_request_with_metadata()
if self._session.client._user_id:
req.user_context.user_id = self._session.client._user_id
req.plan.command.CopyFrom(command)

for attempt in self._session.client._retrying():
with attempt:
# !!HACK ALERT!!
# unary_stream does not work on Python's exit for an unknown reasons
# Therefore, here we open unary_unary channel instead.
# See also :class:`SparkConnectServiceStub`.
request_serializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
)
response_deserializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
)
channel = self._session.client._channel.unary_unary(
"/spark.connect.SparkConnectService/ExecutePlan",
request_serializer=request_serializer,
response_deserializer=response_deserializer,
)
metadata = self._session.client._builder.metadata()
channel(req, metadata=metadata) # type: ignore[arg-type]
except Exception as e:
warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")

def __reduce__(self) -> Tuple:
"""
@@ -2137,7 +2101,6 @@ def checkpoint(self, eager: bool = True) -> "DataFrame":
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
return checkpointed

def localCheckpoint(self, eager: bool = True) -> "DataFrame":
@@ -2146,7 +2109,6 @@ def localCheckpoint(self, eager: bool = True) -> "DataFrame":
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
return checkpointed

if not is_remote_only():
54 changes: 46 additions & 8 deletions python/pyspark/sql/connect/plan.py
@@ -40,6 +40,7 @@
import pickle
from threading import Lock
from inspect import signature, isclass
import warnings

import pyarrow as pa

@@ -49,6 +50,7 @@

import pyspark.sql.connect.proto as proto
from pyspark.sql.column import Column
from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
from pyspark.sql.connect.conversion import storage_level_to_proto
from pyspark.sql.connect.expressions import Expression
from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
@@ -62,6 +64,7 @@
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.observation import Observation
from pyspark.sql.connect.session import SparkSession


class LogicalPlan:
@@ -547,14 +550,49 @@ class CachedRemoteRelation(LogicalPlan):
"""Logical plan object for a DataFrame reference which represents a DataFrame that's been
cached on the server with a given id."""

def __init__(self, relationId: str):
def __init__(self, relation_id: str, spark_session: "SparkSession"):
super().__init__(None)
self._relationId = relationId

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.cached_remote_relation.relation_id = self._relationId
return plan
self._relation_id = relation_id
# Needs to hold the session to make a request itself.
self._spark_session = spark_session

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.cached_remote_relation.relation_id = self._relation_id
return plan

def __del__(self) -> None:
session = self._spark_session
# If session is already closed, all cached DataFrame should be released.
Reviewer (Contributor): Is it possible to release those cached DataFrames on the server side?

Author: I don't think so; we can only tell when to release on the client side.

Author: Without this change, we can only know when the session is disconnected, and we are already releasing everything in that case.

if session is not None and not session.client.is_closed and self._relation_id is not None:
try:
command = RemoveRemoteCachedRelation(self).command(session=session.client)
req = session.client._execute_plan_request_with_metadata()
if session.client._user_id:
req.user_context.user_id = session.client._user_id
req.plan.command.CopyFrom(command)

for attempt in session.client._retrying():
with attempt:
# !!HACK ALERT!!
# unary_stream does not work at Python's exit for unknown reasons.
# Therefore, we open a unary_unary channel here instead.
# See also :class:`SparkConnectServiceStub`.
request_serializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
)
response_deserializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
)
channel = session.client._channel.unary_unary(
"/spark.connect.SparkConnectService/ExecutePlan",
request_serializer=request_serializer,
response_deserializer=response_deserializer,
)
metadata = session.client._builder.metadata()
channel(req, metadata=metadata) # type: ignore[arg-type]
except Exception as e:
warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")


class Hint(LogicalPlan):
@@ -1792,7 +1830,7 @@ def __init__(self, relation: CachedRemoteRelation) -> None:

def command(self, session: "SparkConnectClient") -> proto.Command:
plan = self._create_proto_relation()
plan.cached_remote_relation.relation_id = self._relation._relationId
plan.cached_remote_relation.relation_id = self._relation._relation_id
cmd = proto.Command()
cmd.remove_cached_remote_relation_command.relation.CopyFrom(plan.cached_remote_relation)
return cmd
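To spell out the design the plan.py change above implements: the release of the server-side cached relation now lives in CachedRemoteRelation.__del__ rather than DataFrame.__del__, so any derived DataFrame whose plan tree still references the cached relation keeps it alive, and the RemoveRemoteCachedRelation request is only sent once the last reference is gone. The sketch below uses toy stand-in classes (hypothetical names, not the real Spark Connect DataFrame or plan nodes) to illustrate the CPython reference-counting behaviour this relies on:

```python
# Toy stand-ins for CachedRemoteRelation and DataFrame; only meant to show
# the finalization order, not the Spark Connect API.


class ToyCachedRelation:
    def __init__(self, relation_id: str):
        self.relation_id = relation_id

    def __del__(self) -> None:
        # The real implementation sends a RemoveRemoteCachedRelation command here.
        print(f"released {self.relation_id}")


class ToyDataFrame:
    def __init__(self, plan):
        self._plan = plan

    def repartition(self) -> "ToyDataFrame":
        # A derived frame's plan keeps referencing the cached relation
        # (in the real code it becomes a child of the new plan node).
        return ToyDataFrame(self._plan)


df = ToyDataFrame(ToyCachedRelation("r1"))
derived = df.repartition()
del df       # nothing is released yet: `derived` still reaches the plan
del derived  # last reference dropped, __del__ runs and "released r1" is printed
```

With the old DataFrame-level __del__, deleting the checkpointed frame could release the cache while a derived frame was still using it; that is the scenario the new test_garbage_collection_derived_checkpoint below guards against.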
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/session.py
@@ -926,7 +926,7 @@ def _create_remote_dataframe(self, remote_id: str) -> "ParentDataFrame":
This is used in ForeachBatch() runner, where the remote DataFrame refers to the
output of a micro batch.
"""
return DataFrame(CachedRemoteRelation(remote_id), self)
return DataFrame(CachedRemoteRelation(remote_id, spark_session=self), self)

@staticmethod
def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
53 changes: 47 additions & 6 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -16,10 +16,10 @@
#

import os
import gc
import unittest
import shutil
import tempfile
import time

from pyspark.util import is_remote_only
from pyspark.errors import PySparkTypeError, PySparkValueError
@@ -34,6 +34,7 @@
ArrayType,
Row,
)
from pyspark.testing.utils import eventually
from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.testing.connectutils import (
should_test_connect,
@@ -1379,8 +1380,8 @@ def test_garbage_collection_checkpoint(self):
# SPARK-48258: Make sure garbage-collecting DataFrame remove the paired state
# in Spark Connect server
df = self.connect.range(10).localCheckpoint()
self.assertIsNotNone(df._cached_remote_relation_id)
cached_remote_relation_id = df._cached_remote_relation_id
self.assertIsNotNone(df._plan._relation_id)
cached_remote_relation_id = df._plan._relation_id

jvm = self.spark._jvm
session_holder = getattr(
@@ -1397,14 +1398,54 @@
)

del df
gc.collect()

time.sleep(3) # Make sure removing is triggered, and executed in the server.
def condition():
# Check that the state was removed upon garbage collection.
self.assertIsNone(
session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
)

eventually(catch_assertions=True)(condition)()

def test_garbage_collection_derived_checkpoint(self):
# SPARK-48258: Should keep the cached remote relation when derived DataFrames exist
df = self.connect.range(10).localCheckpoint()
self.assertIsNotNone(df._plan._relation_id)
derived = df.repartition(10)
cached_remote_relation_id = df._plan._relation_id

# Check the state was removed up on garbage-collection.
self.assertIsNone(
jvm = self.spark._jvm
session_holder = getattr(
getattr(
jvm.org.apache.spark.sql.connect.service,
"SparkConnectService$",
),
"MODULE$",
).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id)

# Check the state exists.
self.assertIsNotNone(
session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
)

del df
gc.collect()
Author: Unlike the JVM, this does trigger a full GC.


def condition():
self.assertIsNone(
session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
)

# Should not remove the cache
with self.assertRaises(AssertionError):
eventually(catch_assertions=True, timeout=5)(condition)()

del derived
gc.collect()

eventually(catch_assertions=True)(condition)()


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401
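A note on the test pattern above, following the author's remark that gc.collect() triggers a full GC unlike the JVM's best-effort System.gc(): in CPython, gc.collect() runs synchronously and also collects reference cycles, but the server-side removal is still asynchronous, which is why the tests poll with eventually instead of a fixed time.sleep(3). The sketch below mimics that pattern; retry_until is a simplified stand-in for pyspark.testing.utils.eventually, not its actual implementation, and the Releasable class is a toy placeholder for the cached relation:

```python
import gc
import time


def retry_until(condition, timeout=30.0, interval=0.5):
    """Simplified stand-in for pyspark.testing.utils.eventually: re-run `condition`
    until it stops raising AssertionError or the timeout expires."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            return condition()
        except AssertionError:
            if time.monotonic() >= deadline:
                raise
            time.sleep(interval)


class Releasable:
    """Toy object whose finalizer records that it ran, mimicking the cache release."""

    released = set()

    def __init__(self, relation_id):
        self.relation_id = relation_id

    def __del__(self):
        type(self).released.add(self.relation_id)


obj = Releasable("r1")
del obj
gc.collect()  # CPython: synchronous full collection, also breaks reference cycles.


def condition():
    # Poll until the (possibly asynchronous) side effect of finalization is visible.
    assert "r1" in Releasable.released


retry_until(condition)
print("released ids:", Releasable.released)
```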