Skip to content
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2971,7 +2971,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def unpersist(blocking: Boolean): this.type = {
sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking)
sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, blocking)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also update the comment of line 2966 and line 2979 and explain the new behavior

this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,24 +105,58 @@ class CacheManager extends Logging {
}

/**
* Un-cache all the cache entries that refer to the given plan.
* Un-cache the given plan or all the cache entries that refer to the given plan.
* @param query The [[Dataset]] to be un-cached.
* @param cascade If true, un-cache all the cache entries that refer to the given
* [[Dataset]]; otherwise un-cache the given [[Dataset]] only.
* @param blocking Whether to block until all blocks are deleted.
*/
def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
def uncacheQuery(query: Dataset[_],
cascade: Boolean, blocking: Boolean = true): Unit = writeLock {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indent

uncacheQuery(query.sparkSession, query.logicalPlan, cascade, blocking)
}

/**
* Un-cache all the cache entries that refer to the given plan.
* Un-cache the given plan or all the cache entries that refer to the given plan.
* @param spark The Spark session.
* @param plan The plan to be un-cached.
* @param cascade If true, un-cache all the cache entries that refer to the given
* plan; otherwise un-cache the given plan only.
* @param blocking Whether to block until all blocks are deleted.
*/
def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
def uncacheQuery(spark: SparkSession, plan: LogicalPlan,
cascade: Boolean, blocking: Boolean): Unit = writeLock {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indent.

val shouldRemove: LogicalPlan => Boolean =
if (cascade) {
_.find(_.sameResult(plan)).isDefined
} else {
_.sameResult(plan)
}
val it = cachedData.iterator()
while (it.hasNext) {
val cd = it.next()
if (cd.plan.find(_.sameResult(plan)).isDefined) {
if (shouldRemove(cd.plan)) {
cd.cachedRepresentation.cacheBuilder.clearCache(blocking)
it.remove()
}
}
// Re-compile dependent cached queries after removing the cached query.
if (!cascade) {
val it = cachedData.iterator()
val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
while (it.hasNext) {
val cd = it.next()
if (cd.plan.find(_.sameResult(plan)).isDefined) {
it.remove()
val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan
val newCache = InMemoryRelation(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, if the plan to uncache is iterated after a plan containing it, doesn't this still use its cached plan?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right, although it wouldn't lead to any error just like all other compiled dataframes that refer to this old InMemoryRelation. I'll change this piece of code. But you've brought out another interesting question:
A scenario similar to what you've mentioned:

df2 = df1.filter(...)
df2.cache()
df1.cache()
df1.collect()

, which means we cache the dependent cache first and the cache being depended upon next. Optimally when you do df2.collect(), you would like df2 to use the cached data in df1, but it doesn't work like this now since df2's execution plan has already been generated before we call df1.cache(). It might be worth revisiting the caches and update their plans if necessary when we call cacheQuery()

cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan),
logicalPlan = cd.plan)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
}
needToRecache.foreach(cachedData.add)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create a private function from line 144 and line 158?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's almost the same logic as "recache", except that it tries to reuse the cached buffer here. It would be nice to integrate these two, but it would look so clean given the inconvenience of copying a CacheBuilder. I'll try though.

}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ case class CachedRDDBuilder(
}
}

def withCachedPlan(cachedPlan: SparkPlan): CachedRDDBuilder = {
new CachedRDDBuilder(
useCompression,
batchSize,
storageLevel,
cachedPlan = cachedPlan,
tableName
)(_cachedColumnBuffers)
}

private def buildBuffers(): RDD[CachedBatch] = {
val output = cachedPlan.output
val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ case class DropTableCommand(

override def run(sparkSession: SparkSession): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
val isTempTable = catalog.isTemporaryTable(tableName)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename it to isTempView


if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) {
if (!isTempTable && catalog.tableExists(tableName)) {
// If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view
// issue an exception.
catalog.getTableMetadata(tableName).tableType match {
Expand All @@ -204,9 +205,10 @@ case class DropTableCommand(
}
}

if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) {
if (isTempTable || catalog.tableExists(tableName)) {
try {
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession.table(tableName), !isTempTable)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cascade = !isTempTable

} catch {
case NonFatal(e) => log.warn(e.toString, e)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ case class TruncateTableCommand(
spark.sessionState.refreshTable(tableName.unquotedString)
// Also try to drop the contents of the table from the columnar cache
try {
spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier))
spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier), true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

named argument. cascade = true

} catch {
case NonFatal(e) =>
log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession, viewDef, cascade = false, blocking = true)
sessionCatalog.dropTempView(viewName)
}
}
Expand All @@ -379,7 +380,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropGlobalTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession, viewDef, cascade = false, blocking = true)
sessionCatalog.dropGlobalTempView(viewName)
}
}
Expand Down Expand Up @@ -438,7 +440,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession.table(tableName), !sessionCatalog.isTemporaryTable(tableIdent))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val cascade = !sessionCatalog.isTemporaryTable(tableIdent)

...

}

/**
Expand Down Expand Up @@ -490,7 +494,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
// cached version and make the new version cached lazily.
if (isCached(table)) {
// Uncache the logicalPlan.
sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
sparkSession.sharedState.cacheManager.uncacheQuery(table, true, blocking = true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the same here.

// Cache it again.
sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
Expand Down Expand Up @@ -801,4 +800,67 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
}
assert(cachedData.collect === Seq(1001))
}

test("SPARK-24596 Non-cascading Cache Invalidation - uncache temporary view") {
withView("t1", "t2") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

withTempView

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.. good catch! A mistake caused by copy-paste.

sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1")
sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1")

assert(spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
sql("UNCACHE TABLE t1")
assert(!spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
}
}

test("SPARK-24596 Non-cascading Cache Invalidation - drop temporary view") {
withView("t1", "t2") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1")
sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1")

assert(spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
sql("DROP VIEW t1")
assert(spark.catalog.isCached("t2"))
}
}

test("SPARK-24596 Non-cascading Cache Invalidation - drop persistent view") {
withTable("t") {
spark.range(1, 10).toDF("key").withColumn("value", 'key * 2)
.write.format("json").saveAsTable("t")
withView("t1", "t2") {
sql("CREATE VIEW t1 AS SELECT * FROM t WHERE key > 1")

sql("CACHE TABLE t1")
sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1")

assert(spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
sql("DROP VIEW t1")
assert(!spark.catalog.isCached("t2"))
}
}
}

test("SPARK-24596 Non-cascading Cache Invalidation - uncache table") {
withTable("t") {
spark.range(1, 10).toDF("key").withColumn("value", 'key * 2)
.write.format("json").saveAsTable("t")
withView("t1", "t2") {
sql("CACHE TABLE t")
sql("CACHE TABLE t1 AS SELECT * FROM t WHERE key > 1")
sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1")

assert(spark.catalog.isCached("t"))
assert(spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
sql("UNCACHE TABLE t")
assert(!spark.catalog.isCached("t"))
assert(!spark.catalog.isCached("t1"))
assert(!spark.catalog.isCached("t2"))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ import org.apache.spark.storage.StorageLevel
class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits {
import testImplicits._

/**
* Asserts that a cached [[Dataset]] will be built using the given number of other cached results.
*/
private def assertCacheDependency(df: DataFrame, numOfCachesDependedUpon: Int = 1): Unit = {
val plan = df.queryExecution.withCachedData
assert(plan.isInstanceOf[InMemoryRelation])
val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan
assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).size == numOfCachesDependedUpon)
}

test("get storage level") {
val ds1 = Seq("1", "2").toDS().as("a")
val ds2 = Seq(2, 3).toDS().as("b")
Expand Down Expand Up @@ -117,7 +127,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
}

test("cache UDF result correctly") {
val expensiveUDF = udf({x: Int => Thread.sleep(10000); x})
val expensiveUDF = udf({x: Int => Thread.sleep(5000); x})
val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a"))
val df2 = df.agg(sum(df("b")))

Expand All @@ -126,7 +136,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
assertCached(df2)

// udf has been evaluated during caching, and thus should not be re-evaluated here
failAfter(5 seconds) {
failAfter(3 seconds) {
df2.collect()
}

Expand All @@ -143,9 +153,57 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
df.count()
df2.cache()

val plan = df2.queryExecution.withCachedData
assert(plan.isInstanceOf[InMemoryRelation])
val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan
assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined)
assertCacheDependency(df2)
}

test("SPARK-24596 Non-cascading Cache Invalidation") {
val df = Seq(("a", 1), ("b", 2)).toDF("s", "i")
val df2 = df.filter('i > 1)
val df3 = df.filter('i < 2)

df2.cache()
df.cache()
df.count()
df3.cache()

df.unpersist()

// df un-cached; df2 and df3's cache plan re-compiled
assert(df.storageLevel == StorageLevel.NONE)
assertCacheDependency(df2, 0)
assertCacheDependency(df3, 0)
}

test("SPARK-24596 Non-cascading Cache Invalidation - verify cached data reuse") {
val expensiveUDF = udf({ x: Int => Thread.sleep(5000); x })
val df = spark.range(0, 10).toDF("a")
val df1 = df.withColumn("b", expensiveUDF($"a"))
val df2 = df1.groupBy('a).agg(sum('b))
val df3 = df.agg(sum('a))

df1.cache()
df2.cache()
df2.collect()
df3.cache()

assertCacheDependency(df2)

df1.unpersist(blocking = true)

// df1 un-cached; df2's cache plan re-compiled
assert(df1.storageLevel == StorageLevel.NONE)
assertCacheDependency(df1.groupBy('a).agg(sum('b)), 0)

val df4 = df1.groupBy('a).agg(sum('b)).select("sum(b)")
assertCached(df4)
// reuse loaded cache
failAfter(3 seconds) {
df4.collect()
}

val df5 = df.agg(sum('a)).filter($"sum(a)" > 1)
assertCached(df5)
// first time use, load cache
df5.collect()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how do we prove this takes more than 5 seconds?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We just need to prove the new InMemoryRelation works alright for building cache (since the plan has been re-compiled) ... maybe we should check result though. Plus, I deliberately made this dataframe not dependent on the UDF so it can finish quickly.

}
}