@@ -70,6 +70,7 @@ message Relation {
ApplyInPandasWithState apply_in_pandas_with_state = 34;
HtmlString html_string = 35;
CachedLocalRelation cached_local_relation = 36;
CachedRemoteRelation cached_remote_relation = 37;

// NA functions
NAFill fill_na = 90;
@@ -395,6 +396,18 @@ message CachedLocalRelation {
string hash = 3;
}

// Represents a remote relation that has been cached on the server.
message CachedRemoteRelation {
// (Required) An identifier of the user who cached the relation
string userId = 1;
Contributor Author
We can also just get userId and sessionId from the server via the request, instead of passing them from here.
But that would require updating transformRelation() to take two more parameters, which means all those transform...() methods would need to be updated with two more parameters as well.


// (Required) An identifier of the Spark session in which the relation is cached
string sessionId = 2;
Comment on lines +402 to +405
Contributor
The user and session IDs can't be trusted coming from the proto. The cached relation must carry only the unique relation ID; the rest is resolved from the context of the query.

Agree. This is important. Changed the implementation to use SparkSession as the key (it has a sessionUUID).
[continue the discussion here]


// (Required) A key that represents the id of the cached relation
string relationId = 3;
}

// Relation of type [[Sample]] that samples a fraction of the dataset.
message Sample {
// (Required) Input relation for a Sample.
@@ -143,6 +143,8 @@ class SparkConnectPlanner(val session: SparkSession) {
transformCoGroupMap(rel.getCoGroupMap)
case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
transformApplyInPandasWithState(rel.getApplyInPandasWithState)
case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
transformCachedRemoteRelation(rel.getCachedRemoteRelation)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -786,6 +788,12 @@ class SparkConnectPlanner(val session: SparkSession) {
.logicalPlan
}

private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = {
SparkConnectService.cachedDataFrameManager
.get(rel.getUserId, rel.getSessionId, rel.getRelationId)
.logicalPlan
}
Comment on lines +792 to +795
Contributor
Conceptually, the cached data should come from the session holder that could be passed to the planner instead.

Agree. For now, proposing to keep it in a separate class. Continue the discussion here.
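A minimal sketch of the session-holder idea discussed in this thread, assuming a hypothetical dataFrameCache field on SessionHolder and a planner constructed with the holder (names and signatures are illustrative only, not the actual change):

    // Sketch: only the relation id comes from the proto; the user and session are
    // implied by the SessionHolder, so they cannot be spoofed by the client.
    class SparkConnectPlanner(val sessionHolder: SessionHolder) {
      private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = {
        sessionHolder.dataFrameCache // hypothetical per-session cache
          .get(rel.getRelationId)
          .getOrElse(throw InvalidPlanInput(
            s"No DataFrame found for id ${rel.getRelationId}"))
          .logicalPlan
      }
    }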


private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.common.InvalidPlanInput

/**
* This class caches DataFrames on the server side under given ids. The Spark Connect client can
* create a DataFrame reference with such an id. When the server transforms the DataFrame
* reference, it finds the DataFrame in the cache and replaces the reference with it.
*
* Each (userId, sessionId) pair has a corresponding DataFrame map. A cached DataFrame can only
* be accessed by the same user within the same session. The DataFrame will be removed from the
* cache when the session expires.
*/
private[connect] class SparkConnectCachedDataFrameManager extends Logging {
Contributor
nit: do we need Logging? It doesn't seem to be used.

Removed. See the continuation of this PR: #41580

Contributor
Can we add this class to the session holder to make sure that it is properly associated with the right user ID and session?

Discussed above. SessionHolder is not accessible yet. Also removed session_id and user_id from this cache, instead keying it on the actual Spark session (user_id and session_id are implicit in that).
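A rough sketch of the follow-up design described in the reply above, keying the cache on the SparkSession itself instead of on (userId, sessionId); field and method names here are illustrative only:

    import org.apache.spark.sql.SparkSession

    // The reply above notes a SparkSession carries a unique sessionUUID, so keying
    // on the session means userId/sessionId never need to travel through the proto.
    private val dataFrameCache =
      mutable.Map[SparkSession, mutable.Map[String, DataFrame]]()

    def put(session: SparkSession, dataFrameId: String, df: DataFrame): Unit = synchronized {
      dataFrameCache
        .getOrElseUpdate(session, mutable.Map[String, DataFrame]())
        .put(dataFrameId, df)
    }

    def remove(session: SparkSession): Unit = synchronized {
      dataFrameCache.remove(session)
    }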


// Each (userId, sessionId) has a DataFrame cache map.
private val dataFrameCache = mutable.Map[(String, String), mutable.Map[String, DataFrame]]()

def put(userId: String, sessionId: String, dataFrameId: String, value: DataFrame): Unit =
synchronized {
val sessionKey = (userId, sessionId)
val sessionDataFrameMap = dataFrameCache
.getOrElseUpdate(sessionKey, mutable.Map[String, DataFrame]())
sessionDataFrameMap.put(dataFrameId, value)
}
Comment on lines +38 to +46
Contributor @zhenlineo (Jun 12, 2023)
How about using two concurrent hash maps + compute to avoid synchronized? For example:

    dataFrameCache.compute(sessionKey, (key, sessionDataFrameMap) => {
      val newMap =
        if (sessionDataFrameMap == null) new ConcurrentHashMap[String, DataFrame]()
        else sessionDataFrameMap
      newMap.put(dataFrameId, value)
      newMap // compute must return the value to keep in the cache
    })

Similar logic applies for remove.
For get, you can just read without needing to lock explicitly.

Contributor
only my personal taste:

I feel like synchronized is easier to reason about than ConcurrentHashMap for code readers, unless there is a significant performance gain from switching to a concurrent data structure.

Agree, we could use ConcurrentHashMap, but I often end up preferring synchronized as well. Since this is not perf critical (used only for certain DataFrames), I am not sure there would be any perf difference either way.
Added a @GuardedBy annotation.
See the continuation of this PR here: https://github.com/apache/spark/pull/41580/files#diff-1a8933e9723f5497c3991441c7ff21fe43db63d483354af9a0113043ea600b3eR42

Comment on lines +40 to +46
Contributor
This would make things easier as well because you would only have one concurrent map.
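If the cache lives inside the session holder as suggested, it collapses to a single concurrent map per session and the synchronized blocks can go away entirely; a hedged sketch under that assumption (names hypothetical):

    import java.util.concurrent.ConcurrentHashMap

    // One map per session: relationId -> DataFrame. No nested maps and no explicit
    // locking; ConcurrentHashMap provides the thread safety.
    private val dataFrameCache = new ConcurrentHashMap[String, DataFrame]()

    def put(dataFrameId: String, df: DataFrame): Unit =
      dataFrameCache.put(dataFrameId, df)

    def get(dataFrameId: String): DataFrame = {
      val df = dataFrameCache.get(dataFrameId)
      if (df == null) {
        throw InvalidPlanInput(s"No DataFrame found in the server cache for key = $dataFrameId")
      }
      df
    }

    // Removing a session's cache is then just dropping this object together with
    // its SessionHolder when the session expires.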


def get(userId: String, sessionId: String, dataFrameId: String): DataFrame = synchronized {
val sessionKey = (userId, sessionId)

val notFoundException = InvalidPlanInput(
s"No DataFrame found in the server cache for key = $dataFrameId in the session $sessionId " +
s"for the user id $userId.")

val sessionDataFrameMap = dataFrameCache.getOrElse(sessionKey, throw notFoundException)
sessionDataFrameMap.getOrElse(dataFrameId, throw notFoundException)
}

def remove(userId: String, sessionId: String): Unit = synchronized {
dataFrameCache.remove((userId, sessionId))
}
}
@@ -298,12 +298,15 @@ object SparkConnectService {
userSessionMapping.getIfPresent((userId, sessionId))
})

private[connect] val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

private class RemoveSessionListener extends RemovalListener[SessionCacheKey, SessionHolder] {
override def onRemoval(
notification: RemovalNotification[SessionCacheKey, SessionHolder]): Unit = {
val SessionHolder(userId, sessionId, session) = notification.getValue
val blockManager = session.sparkContext.env.blockManager
blockManager.removeCache(userId, sessionId)
cachedDataFrameManager.remove(userId, sessionId)
Note to self: this reference should be removed from the streaming engine once the foreachBatch completes.

}
}
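Regarding the note-to-self above: one way the streaming side could release a single cached DataFrame once a foreachBatch invocation finishes is a per-relation removal, sketched below with a hypothetical overload that is not part of this PR:

    // Hypothetical overload on SparkConnectCachedDataFrameManager: drop one cached
    // DataFrame instead of the whole (userId, sessionId) entry.
    def remove(userId: String, sessionId: String, dataFrameId: String): Unit = synchronized {
      dataFrameCache.get((userId, sessionId)).foreach(_.remove(dataFrameId))
    }

    // The foreachBatch runner would call it after the user function returns:
    // cachedDataFrameManager.remove(userId, sessionId, batchDfId)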

@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import java.util.UUID

import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.test.SharedSparkSession

class SparkConnectCachedDataFrameManagerSuite extends SharedSparkSession {

test("Successful put and get") {
val spark = this.spark
import spark.implicits._

val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

val userId = UUID.randomUUID().toString
val sessionId = UUID.randomUUID().toString

val key1 = "key_1"
val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
val df1 = data1.toDF()
cachedDataFrameManager.put(userId, sessionId, key1, df1)

val expectedDf1 = cachedDataFrameManager.get(userId, sessionId, key1)
assert(expectedDf1 == df1)

val key2 = "key_2"
val data2 = Seq(("k4", "v4"), ("k5", "v5"))
val df2 = data2.toDF()
cachedDataFrameManager.put(userId, sessionId, key2, df2)

val expectedDf2 = cachedDataFrameManager.get(userId, sessionId, key2)
assert(expectedDf2 == df2)
}

test("Get cache that does not exist should fail") {
val spark = this.spark
import spark.implicits._

val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

val userId = UUID.randomUUID().toString
val sessionId = UUID.randomUUID().toString
val key1 = "key_1"

assertThrows[InvalidPlanInput] {
cachedDataFrameManager.get(userId, sessionId, key1)
}

val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
val df1 = data1.toDF()
cachedDataFrameManager.put(userId, sessionId, key1, df1)
cachedDataFrameManager.get(userId, sessionId, key1)

val key2 = "key_2"
assertThrows[InvalidPlanInput] {
cachedDataFrameManager.get(userId, sessionId, key2)
}
}

test("Remove cache and then get should fail") {
val spark = this.spark
import spark.implicits._

val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

val userId = UUID.randomUUID().toString
val sessionId = UUID.randomUUID().toString

val key1 = "key_1"
val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
val df1 = data1.toDF()
cachedDataFrameManager.put(userId, sessionId, key1, df1)
cachedDataFrameManager.get(userId, sessionId, key1)

cachedDataFrameManager.remove(userId, sessionId)
assertThrows[InvalidPlanInput] {
cachedDataFrameManager.get(userId, sessionId, key1)
}
}
}
18 changes: 18 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -517,6 +517,24 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan


class CachedRemoteRelation(LogicalPlan):
"""Logical plan object for a DataFrame reference which represents a DataFrame that's been
cached on the server with a given id."""

def __init__(self, relationId: str):
super().__init__(None)
self._relationId = relationId

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.cached_remote_relation.relationId = self._relationId
if session._user_id:
plan.cached_remote_relation.userId = session._user_id
plan.cached_remote_relation.sessionId = session._session_id

return plan


class Hint(LogicalPlan):
"""Logical plan object for a Hint operation."""

270 changes: 142 additions & 128 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

39 changes: 39 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -98,6 +98,7 @@ class Relation(google.protobuf.message.Message):
APPLY_IN_PANDAS_WITH_STATE_FIELD_NUMBER: builtins.int
HTML_STRING_FIELD_NUMBER: builtins.int
CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
CACHED_REMOTE_RELATION_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -185,6 +186,8 @@ class Relation(google.protobuf.message.Message):
@property
def cached_local_relation(self) -> global___CachedLocalRelation: ...
@property
def cached_remote_relation(self) -> global___CachedRemoteRelation: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -257,6 +260,7 @@ class Relation(google.protobuf.message.Message):
apply_in_pandas_with_state: global___ApplyInPandasWithState | None = ...,
html_string: global___HtmlString | None = ...,
cached_local_relation: global___CachedLocalRelation | None = ...,
cached_remote_relation: global___CachedRemoteRelation | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
b"approx_quantile",
"cached_local_relation",
b"cached_local_relation",
"cached_remote_relation",
b"cached_remote_relation",
"catalog",
b"catalog",
"co_group_map",
@@ -390,6 +396,8 @@
b"approx_quantile",
"cached_local_relation",
b"cached_local_relation",
"cached_remote_relation",
b"cached_remote_relation",
"catalog",
b"catalog",
"co_group_map",
@@ -524,6 +532,7 @@
"apply_in_pandas_with_state",
"html_string",
"cached_local_relation",
"cached_remote_relation",
"fill_na",
"drop_na",
"replace",
@@ -1593,6 +1602,36 @@

global___CachedLocalRelation = CachedLocalRelation

class CachedRemoteRelation(google.protobuf.message.Message):
"""Represents a remote relation that has been cached on server."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

USERID_FIELD_NUMBER: builtins.int
SESSIONID_FIELD_NUMBER: builtins.int
RELATIONID_FIELD_NUMBER: builtins.int
userId: builtins.str
"""(Required) An identifier of the user which cached the relation"""
sessionId: builtins.str
"""(Required) An identifier of the Spark session in which the relation is cached"""
relationId: builtins.str
"""(Required) A key represents the id of the cached relation"""
def __init__(
self,
*,
userId: builtins.str = ...,
sessionId: builtins.str = ...,
relationId: builtins.str = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"relationId", b"relationId", "sessionId", b"sessionId", "userId", b"userId"
],
) -> None: ...

global___CachedRemoteRelation = CachedRemoteRelation

class Sample(google.protobuf.message.Message):
"""Relation of type [[Sample]] that samples a fraction of the dataset."""

7 changes: 6 additions & 1 deletion python/pyspark/sql/connect/session.py
@@ -51,7 +51,7 @@
from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
from pyspark.sql.connect.conf import RuntimeConf
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation
from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation, CachedRemoteRelation
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.streaming import DataStreamReader, StreamingQueryManager
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
@@ -476,6 +476,11 @@ def createDataFrame(

createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__

def _createCachedDataFrame(self, relationId: str) -> "DataFrame":
Contributor
this seems to be unused here?

Removed. It will be used in the foreachBatch implementation (in follow-up PRs).

"""This function creates a DataFrame given a id. It expects that on the server side, the
actual DataFrame is cached with relationId as the key. This function is for internal use."""
return DataFrame.withPlan(CachedRemoteRelation(relationId), self)

def sql(self, sqlQuery: str, args: Optional[Dict[str, Any]] = None) -> "DataFrame":
cmd = SQL(sqlQuery, args)
data, properties = self.client.execute_command(cmd.command(self._client))
Expand Down