[SPARK-43474] [SS] [CONNECT] Add a spark connect function to create DataFrame reference #41146
Changes from all commits: ea07c1b, f425deb, 80cbe55, 4924f91, debacf6, 51b7a31, 4d0858a, 443f175
@@ -70,6 +70,7 @@ message Relation {
    ApplyInPandasWithState apply_in_pandas_with_state = 34;
    HtmlString html_string = 35;
    CachedLocalRelation cached_local_relation = 36;
    CachedRemoteRelation cached_remote_relation = 37;

    // NA functions
    NAFill fill_na = 90;

@@ -395,6 +396,18 @@ message CachedLocalRelation {
  string hash = 3;
}

// Represents a remote relation that has been cached on server.
message CachedRemoteRelation {
  // (Required) An identifier of the user which cached the relation
  string userId = 1;

  // (Required) An identifier of the Spark session in which the relation is cached
  string sessionId = 2;
Comment on lines +402 to +405

Contributor: The user and session ID can't be trusted coming from the proto. The cached relation must only carry the actual unique ID of the relation; the rest is resolved from the context of the query.

Reply: Agree. This is important. Changed the implementation to use SparkSession as the key (user id and session id are implicit in it).
  // (Required) A key represents the id of the cached relation
  string relationId = 3;
}

// Relation of type [[Sample]] that samples a fraction of the dataset.
message Sample {
  // (Required) Input relation for a Sample.
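As a concrete illustration of the new message, a client (or server-internal code) could build it with the generated protobuf builders roughly as follows; the object and method names are hypothetical and not part of this PR:

import org.apache.spark.connect.proto

object CachedRemoteRelationExample {
  // Illustrative construction of a Relation carrying a CachedRemoteRelation,
  // assuming the standard Java/Scala builders generated from the proto above.
  def build(userId: String, sessionId: String, relationId: String): proto.Relation =
    proto.Relation
      .newBuilder()
      .setCachedRemoteRelation(
        proto.CachedRemoteRelation
          .newBuilder()
          .setUserId(userId)
          .setSessionId(sessionId)
          .setRelationId(relationId)
          .build())
      .build()
}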
@@ -143,6 +143,8 @@ class SparkConnectPlanner(val session: SparkSession) {
        transformCoGroupMap(rel.getCoGroupMap)
      case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
        transformApplyInPandasWithState(rel.getApplyInPandasWithState)
      case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
        transformCachedRemoteRelation(rel.getCachedRemoteRelation)
      case proto.Relation.RelTypeCase.COLLECT_METRICS =>
        transformCollectMetrics(rel.getCollectMetrics)
      case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)

@@ -786,6 +788,12 @@ class SparkConnectPlanner(val session: SparkSession) {
      .logicalPlan
  }

  private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = {
    SparkConnectService.cachedDataFrameManager
      .get(rel.getUserId, rel.getSessionId, rel.getRelationId)
      .logicalPlan
  }
Comment on lines +792 to +795

Contributor: Conceptually, the cached data should come from the session holder, which could be passed to the planner instead.

Reply: Agree. For now proposing to keep it in a separate class. Continue discussion here.
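For reference, the suggestion above might look roughly like the sketch below; the class name and the dataFrameCache field on SessionHolder are illustrative assumptions, not code from this PR:

package org.apache.spark.sql.connect.planner

import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.service.SessionHolder

// Hypothetical variant: the planner is handed the SessionHolder, so the user and
// session identity come from the connection context and only the relation id is
// read from the proto message.
class SparkConnectPlannerWithHolder(sessionHolder: SessionHolder) {

  private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = {
    sessionHolder.dataFrameCache // assumed per-session map on SessionHolder
      .get(rel.getRelationId)
      .getOrElse(throw InvalidPlanInput(
        s"No DataFrame found in the server cache for id = ${rel.getRelationId}"))
      .logicalPlan
  }
}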
  private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
    Dataset
      .ofRows(session, transformRelation(rel.getInput))
@@ -0,0 +1,62 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.service

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.common.InvalidPlanInput

/**
 * This class caches DataFrames on the server side under given ids. The Spark Connect client can
 * create a DataFrame reference with the id. When the server transforms the DataFrame reference,
 * it finds the DataFrame in the cache and replaces the reference.
 *
 * Each (userId, sessionId) has a corresponding DataFrame map. A cached DataFrame can only be
 * accessed by the same user within the same session. The DataFrame will be removed from the
 * cache when the session expires.
 */
private[connect] class SparkConnectCachedDataFrameManager extends Logging {
Contributor: nit: do we need Logging? It is not used.

Reply: Removed. See the continuation of this PR: #41580

Contributor: Can we add this class to the session holder to make sure that it is properly associated with the right user ID and session?

Reply: Discussed above. SessionHolder is not accessible yet. Also removed session_id and user_id from this cache, instead keying it on the actual Spark session (user_id and session_id are implicit in that).
|
|
  // Each (userId, sessionId) has a DataFrame cache map.
  private val dataFrameCache = mutable.Map[(String, String), mutable.Map[String, DataFrame]]()

  def put(userId: String, sessionId: String, dataFrameId: String, value: DataFrame): Unit =
    synchronized {
      val sessionKey = (userId, sessionId)
      val sessionDataFrameMap = dataFrameCache
        .getOrElseUpdate(sessionKey, mutable.Map[String, DataFrame]())
      sessionDataFrameMap.put(dataFrameId, value)
    }
Comment on lines +38 to +46

Contributor: How about using two concurrent hash maps [...]. Similar logic applies for [...]

Contributor: Only my personal taste: I feel like [...]

Reply: Agree, we could use ConcurrentHashMap. But I often end up preferring [...]

Comment on lines +40 to +46

Contributor: This will make this easier as well, because you only have one concurrent map.
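A rough sketch of the single concurrent map variant discussed above, assuming a flat key of (userId, sessionId, dataFrameId); the class name and structure are illustrative, not the PR's implementation:

import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.common.InvalidPlanInput

// Hypothetical alternative: one ConcurrentHashMap keyed by (userId, sessionId, dataFrameId),
// which avoids the explicit synchronized blocks around nested per-session maps.
private[connect] class ConcurrentCachedDataFrameManager {
  private val cache = new ConcurrentHashMap[(String, String, String), DataFrame]()

  def put(userId: String, sessionId: String, dataFrameId: String, value: DataFrame): Unit =
    cache.put((userId, sessionId, dataFrameId), value)

  def get(userId: String, sessionId: String, dataFrameId: String): DataFrame = {
    val df = cache.get((userId, sessionId, dataFrameId))
    if (df == null) {
      throw InvalidPlanInput(
        s"No DataFrame found in the server cache for key = $dataFrameId in the session " +
          s"$sessionId for the user id $userId.")
    }
    df
  }

  // The trade-off of a flat map: removing a whole session requires scanning all keys.
  def remove(userId: String, sessionId: String): Unit =
    cache.keySet.removeIf((key: (String, String, String)) =>
      key._1 == userId && key._2 == sessionId)
}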
|
|
  def get(userId: String, sessionId: String, dataFrameId: String): DataFrame = synchronized {
    val sessionKey = (userId, sessionId)

    val notFoundException = InvalidPlanInput(
      s"No DataFrame found in the server cache for key = $dataFrameId in the session $sessionId " +
        s"for the user id $userId.")

    val sessionDataFrameMap = dataFrameCache.getOrElse(sessionKey, throw notFoundException)
    sessionDataFrameMap.getOrElse(dataFrameId, throw notFoundException)
  }

  def remove(userId: String, sessionId: String): Unit = synchronized {
    dataFrameCache.remove((userId, sessionId))
  }
}
@@ -298,12 +298,15 @@ object SparkConnectService {
      userSessionMapping.getIfPresent((userId, sessionId))
    })

  private[connect] val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

  private class RemoveSessionListener extends RemovalListener[SessionCacheKey, SessionHolder] {
    override def onRemoval(
        notification: RemovalNotification[SessionCacheKey, SessionHolder]): Unit = {
      val SessionHolder(userId, sessionId, session) = notification.getValue
      val blockManager = session.sparkContext.env.blockManager
      blockManager.removeCache(userId, sessionId)
      cachedDataFrameManager.remove(userId, sessionId)
Note to self: This reference should be removed from the streaming engine once the foreachBatch completes.
    }
  }
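In line with the note above, a rough sketch of how a server-side streaming foreachBatch wrapper might use this cache; the wrapper is illustrative only, and the actual wiring is deferred to follow-up work:

package org.apache.spark.sql.connect.service

import java.util.UUID

import org.apache.spark.sql.DataFrame

object ForeachBatchCacheExample {
  // Hypothetical wrapper: each micro-batch DataFrame is cached under a fresh relation id
  // so the client-side user function can reference it via CachedRemoteRelation. The cache
  // entry is dropped when the session expires (RemoveSessionListener above); per the note
  // above, it should also be released once the batch completes.
  def wrap(
      userId: String,
      sessionId: String,
      runUserFunction: (String, Long) => Unit): (DataFrame, Long) => Unit = {
    (df: DataFrame, batchId: Long) =>
      val relationId = UUID.randomUUID().toString
      SparkConnectService.cachedDataFrameManager.put(userId, sessionId, relationId, df)
      runUserFunction(relationId, batchId)
  }
}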
@@ -0,0 +1,98 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.service

import java.util.UUID

import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.test.SharedSparkSession

class SparkConnectCachedDataFrameManagerSuite extends SharedSparkSession {

  test("Successful put and get") {
    val spark = this.spark
    import spark.implicits._

    val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

    val userId = UUID.randomUUID().toString
    val sessionId = UUID.randomUUID().toString

    val key1 = "key_1"
    val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
    val df1 = data1.toDF()
    cachedDataFrameManager.put(userId, sessionId, key1, df1)

    val expectedDf1 = cachedDataFrameManager.get(userId, sessionId, key1)
    assert(expectedDf1 == df1)

    val key2 = "key_2"
    val data2 = Seq(("k4", "v4"), ("k5", "v5"))
    val df2 = data2.toDF()
    cachedDataFrameManager.put(userId, sessionId, key2, df2)

    val expectedDf2 = cachedDataFrameManager.get(userId, sessionId, key2)
    assert(expectedDf2 == df2)
  }

  test("Get cache that does not exist should fail") {
    val spark = this.spark
    import spark.implicits._

    val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

    val userId = UUID.randomUUID().toString
    val sessionId = UUID.randomUUID().toString
    val key1 = "key_1"

    assertThrows[InvalidPlanInput] {
      cachedDataFrameManager.get(userId, sessionId, key1)
    }

    val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
    val df1 = data1.toDF()
    cachedDataFrameManager.put(userId, sessionId, key1, df1)
    cachedDataFrameManager.get(userId, sessionId, key1)

    val key2 = "key_2"
    assertThrows[InvalidPlanInput] {
      cachedDataFrameManager.get(userId, sessionId, key2)
    }
  }

  test("Remove cache and then get should fail") {
    val spark = this.spark
    import spark.implicits._

    val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

    val userId = UUID.randomUUID().toString
    val sessionId = UUID.randomUUID().toString

    val key1 = "key_1"
    val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
    val df1 = data1.toDF()
    cachedDataFrameManager.put(userId, sessionId, key1, df1)
    cachedDataFrameManager.get(userId, sessionId, key1)

    cachedDataFrameManager.remove(userId, sessionId)
    assertThrows[InvalidPlanInput] {
      cachedDataFrameManager.get(userId, sessionId, key1)
    }
  }
}
Large diffs are not rendered by default.
@@ -51,7 +51,7 @@
from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
from pyspark.sql.connect.conf import RuntimeConf
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation, CachedRemoteRelation
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.streaming import DataStreamReader, StreamingQueryManager
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer

@@ -476,6 +476,11 @@ def createDataFrame(
    createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__

    def _createCachedDataFrame(self, relationId: str) -> "DataFrame":
Contributor: this seems to be unused here?

Reply: Removed. It will be used in the foreachBatch implementation (in follow-up PRs).
| """This function creates a DataFrame given a id. It expects that on the server side, the | ||
| actual DataFrame is cached with relationId as the key. This function is for internal use.""" | ||
| return DataFrame.withPlan(CachedRemoteRelation(relationId), self) | ||
|
|
||
| def sql(self, sqlQuery: str, args: Optional[Dict[str, Any]] = None) -> "DataFrame": | ||
| cmd = SQL(sqlQuery, args) | ||
| data, properties = self.client.execute_command(cmd.command(self._client)) | ||
|
|
||
We can also just get userId and sessionId from the server via the request, instead of passing them from here. But that would require updating transformRelation() to take two more parameters, which means all those transform...() methods would need to be updated to take two more parameters, as sketched below.
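The alternative described above could look roughly like the following sketch; the class name, extra parameters, and error branch are illustrative, not code from this PR:

package org.apache.spark.sql.connect.planner

import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.service.SparkConnectService

class SparkConnectPlannerSketch {
  // Hypothetical signature: the caller's identity is threaded through the planner,
  // so CachedRemoteRelation would only need to carry the relation id.
  def transformRelation(rel: proto.Relation, userId: String, sessionId: String): LogicalPlan =
    rel.getRelTypeCase match {
      case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
        SparkConnectService.cachedDataFrameManager
          .get(userId, sessionId, rel.getCachedRemoteRelation.getRelationId)
          .logicalPlan
      case other =>
        // Every other transform...() would likewise need the two extra parameters threaded through.
        throw InvalidPlanInput(s"$other is not handled in this sketch")
    }
}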