@@ -70,6 +70,7 @@ message Relation {
ApplyInPandasWithState apply_in_pandas_with_state = 34;
HtmlString html_string = 35;
CachedLocalRelation cached_local_relation = 36;
CachedRemoteRelation cached_remote_relation = 37;

// NA functions
NAFill fill_na = 90;
@@ -398,6 +399,12 @@ message CachedLocalRelation {
string hash = 3;
}

// Represents a remote relation that has been cached on the server.
message CachedRemoteRelation {
// (Required) ID of the remote relation (assigned by the service).
string relation_id = 1;
}
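
A minimal sketch of building this message with the generated Java/Scala protobuf API; the builder methods follow mechanically from the field names, and the ID "df_1" is a placeholder:

    import org.apache.spark.connect.proto

    // A Relation whose body is a CachedRemoteRelation referencing a
    // server-side DataFrame by its ID.
    val relation = proto.Relation
      .newBuilder()
      .setCachedRemoteRelation(
        proto.CachedRemoteRelation
          .newBuilder()
          .setRelationId("df_1")
          .build())
      .build()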

// Relation of type [[Sample]] that samples a fraction of the dataset.
message Sample {
// (Required) Input relation for a Sample.
@@ -160,6 +160,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
transformCoGroupMap(rel.getCoGroupMap)
case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
transformApplyInPandasWithState(rel.getApplyInPandasWithState)
case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
transformCachedRemoteRelation(rel.getCachedRemoteRelation)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -872,6 +874,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
.logicalPlan
}

private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = {
sessionHolder
.getDataFrameOrThrow(rel.getRelationId)
.logicalPlan
}
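
A hedged end-to-end sketch of this path, assuming a DataFrame has already been cached under the placeholder ID "df_1" (for example by the streaming foreachBatch machinery this cache is meant for):

    // Server side: register the DataFrame so clients can reference it by ID.
    sessionHolder.cacheDataFrameById("df_1", df)

    // An incoming plan that carries a CachedRemoteRelation with that ID then
    // resolves to the cached DataFrame's analyzed logical plan.
    val rel = proto.Relation
      .newBuilder()
      .setCachedRemoteRelation(
        proto.CachedRemoteRelation.newBuilder().setRelationId("df_1").build())
      .build()
    val resolved = new SparkConnectPlanner(sessionHolder).transformRelation(rel)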

private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
@@ -28,10 +28,13 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.JobArtifactSet
import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.util.Utils

/**
@@ -43,6 +46,10 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession
val executePlanOperations: ConcurrentMap[String, ExecutePlanHolder] =
new ConcurrentHashMap[String, ExecutePlanHolder]()

// Mapping from relation ID (passed to the client) to the runtime DataFrame. Used for
// callbacks such as foreachBatch() in Streaming. Lazy since most sessions don't need it.
private lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap()

private[connect] def createExecutePlanHolder(
request: proto.ExecutePlanRequest): ExecutePlanHolder = {

@@ -163,6 +170,31 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession
}
}
}

/**
* Caches the given DataFrame under the given ID. The cache does not expire; the entry must be
* explicitly removed by the owner of the DataFrame once it is no longer needed (see the
* lifecycle sketch after this class).
*/
private[connect] def cacheDataFrameById(dfId: String, df: DataFrame): Unit = {
if (dataFrameCache.putIfAbsent(dfId, df) != null) {
SparkException.internalError(s"A dataframe is already associated with id $dfId")
> Review comment (Member): internalError only creates the SparkException, so it apparently needs to be thrown. PR #44400 contains a minor fix of this mistake and another one.
}
}

/**
* Returns the [[DataFrame]] cached for DataFrame ID `dfId`. Throws [[InvalidPlanInput]] if it
* is not found.
*/
private[connect] def getDataFrameOrThrow(dfId: String): DataFrame = {
Option(dataFrameCache.get(dfId))
.getOrElse {
throw InvalidPlanInput(s"No DataFrame with id $dfId is found in the session $sessionId")
> Review comment (Member): The same here; how about introducing an error class?
> Reply (Author): Maybe not needed, since InvalidPlanInput is used widely for this exact purpose and this is less user-visible.

}
}

private[connect] def removeCachedDataFrame(dfId: String): DataFrame = {
dataFrameCache.remove(dfId)
}
}
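
A minimal lifecycle sketch for this cache, using the SessionHolder.forTesting helper that the test suite below relies on (IDs and data are placeholders):

    val holder = SessionHolder.forTesting(spark)
    val df = spark.range(3).toDF()

    // Put: registering a second DataFrame under the same ID is an internal error.
    holder.cacheDataFrameById("df_1", df)

    // Get: unknown IDs raise InvalidPlanInput.
    val cached = holder.getDataFrameOrThrow("df_1")

    // The owner removes the entry once it is no longer needed; entries never
    // expire on their own.
    holder.removeCachedDataFrame("df_1")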

object SessionHolder {
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.test.SharedSparkSession

class SparkConnectSessionHolderSuite extends SharedSparkSession {

test("DataFrame cache: Successful put and get") {
val sessionHolder = SessionHolder.forTesting(spark)
import sessionHolder.session.implicits._

val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
val df1 = data1.toDF()
val id1 = "df_id_1"
sessionHolder.cacheDataFrameById(id1, df1)

val expectedDf1 = sessionHolder.getDataFrameOrThrow(id1)
assert(expectedDf1 == df1)

val data2 = Seq(("k4", "v4"), ("k5", "v5"))
val df2 = data2.toDF()
val id2 = "df_id_2"
sessionHolder.cacheDataFrameById(id2, df2)

val expectedDf2 = sessionHolder.getDataFrameOrThrow(id2)
assert(expectedDf2 == df2)
}

test("DataFrame cache: Should throw when dataframe is not found") {
val sessionHolder = SessionHolder.forTesting(spark)
import sessionHolder.session.implicits._

val key1 = "key_1"

assertThrows[InvalidPlanInput] {
sessionHolder.getDataFrameOrThrow(key1)
}

val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
val df1 = data1.toDF()
sessionHolder.cacheDataFrameById(key1, df1)
sessionHolder.getDataFrameOrThrow(key1)

val key2 = "key_2"
assertThrows[InvalidPlanInput] {
sessionHolder.getDataFrameOrThrow(key2)
}
}

test("DataFrame cache: Remove cache and then get should fail") {
val sessionHolder = SessionHolder.forTesting(spark)
import sessionHolder.session.implicits._

val key1 = "key_1"
val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
val df1 = data1.toDF()
sessionHolder.cacheDataFrameById(key1, df1)
sessionHolder.getDataFrameOrThrow(key1)

sessionHolder.removeCachedDataFrame(key1)
assertThrows[InvalidPlanInput] {
sessionHolder.getDataFrameOrThrow(key1)
}
}
}
14 changes: 14 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -551,6 +551,20 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan


class CachedRemoteRelation(LogicalPlan):
"""Logical plan object for a DataFrame reference which represents a DataFrame that's been
cached on the server with a given id."""

def __init__(self, relationId: str):
super().__init__(None)
self._relationId = relationId

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.cached_remote_relation.relation_id = self._relationId
return plan


class Hint(LogicalPlan):
"""Logical plan object for a Hint operation."""

270 changes: 142 additions & 128 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -98,6 +98,7 @@ class Relation(google.protobuf.message.Message):
APPLY_IN_PANDAS_WITH_STATE_FIELD_NUMBER: builtins.int
HTML_STRING_FIELD_NUMBER: builtins.int
CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
CACHED_REMOTE_RELATION_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -185,6 +186,8 @@ class Relation(google.protobuf.message.Message):
@property
def cached_local_relation(self) -> global___CachedLocalRelation: ...
@property
def cached_remote_relation(self) -> global___CachedRemoteRelation: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -257,6 +260,7 @@
apply_in_pandas_with_state: global___ApplyInPandasWithState | None = ...,
html_string: global___HtmlString | None = ...,
cached_local_relation: global___CachedLocalRelation | None = ...,
cached_remote_relation: global___CachedRemoteRelation | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -283,6 +287,8 @@
b"approx_quantile",
"cached_local_relation",
b"cached_local_relation",
"cached_remote_relation",
b"cached_remote_relation",
"catalog",
b"catalog",
"co_group_map",
@@ -390,6 +396,8 @@
b"approx_quantile",
"cached_local_relation",
b"cached_local_relation",
"cached_remote_relation",
b"cached_remote_relation",
"catalog",
b"catalog",
"co_group_map",
@@ -524,6 +532,7 @@
"apply_in_pandas_with_state",
"html_string",
"cached_local_relation",
"cached_remote_relation",
"fill_na",
"drop_na",
"replace",
@@ -1608,6 +1617,25 @@ class CachedLocalRelation(google.protobuf.message.Message):

global___CachedLocalRelation = CachedLocalRelation

class CachedRemoteRelation(google.protobuf.message.Message):
"""Represents a remote relation that has been cached on server."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

RELATION_ID_FIELD_NUMBER: builtins.int
relation_id: builtins.str
"""(Required) ID of the remote related (assigned by the service)."""
def __init__(
self,
*,
relation_id: builtins.str = ...,
) -> None: ...
def ClearField(
self, field_name: typing_extensions.Literal["relation_id", b"relation_id"]
) -> None: ...

global___CachedRemoteRelation = CachedRemoteRelation

class Sample(google.protobuf.message.Message):
"""Relation of type [[Sample]] that samples a fraction of the dataset."""
