Skip to content
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2971,7 +2971,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def unpersist(blocking: Boolean): this.type = {
sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking)
sparkSession.sharedState.cacheManager.uncacheQuery(this, false, blocking)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it's clearer to write cascade =false

this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,22 +107,35 @@ class CacheManager extends Logging {
/**
* Un-cache all the cache entries that refer to the given plan.
*/
def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
def uncacheQuery(query: Dataset[_],
cascade: Boolean, blocking: Boolean = true): Unit = writeLock {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

def f(
    param1: X,
    param2: Y)....

4 space indentation.

uncacheQuery(query.sparkSession, query.logicalPlan, cascade, blocking)
}

/**
* Un-cache all the cache entries that refer to the given plan.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should update this document.

*/
def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
def uncacheQuery(spark: SparkSession, plan: LogicalPlan,
cascade: Boolean, blocking: Boolean): Unit = writeLock {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

val it = cachedData.iterator()
val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
while (it.hasNext) {
val cd = it.next()
if (cd.plan.find(_.sameResult(plan)).isDefined) {
cd.cachedRepresentation.cacheBuilder.clearCache(blocking)
it.remove()
if (cascade || cd.plan.sameResult(plan)) {
cd.cachedRepresentation.cacheBuilder.clearCache(blocking)
} else {
val plan = spark.sessionState.executePlan(cd.plan).executedPlan
val newCache = InMemoryRelation(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, if the plan to uncache is iterated after a plan containing it, doesn't this still use its cached plan?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right, although it wouldn't lead to any error just like all other compiled dataframes that refer to this old InMemoryRelation. I'll change this piece of code. But you've brought out another interesting question:
A scenario similar to what you've mentioned:

df2 = df1.filter(...)
df2.cache()
df1.cache()
df1.collect()

, which means we cache the dependent cache first and the cache being depended upon next. Optimally when you do df2.collect(), you would like df2 to use the cached data in df1, but it doesn't work like this now since df2's execution plan has already been generated before we call df1.cache(). It might be worth revisiting the caches and update their plans if necessary when we call cacheQuery()

cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan),
logicalPlan = cd.plan)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
}
}

needToRecache.foreach(cachedData.add)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ case class CachedRDDBuilder(
}
}

def withCachedPlan(cachedPlan: SparkPlan): CachedRDDBuilder = {
new CachedRDDBuilder(
useCompression,
batchSize,
storageLevel,
cachedPlan = cachedPlan,
tableName
)(_cachedColumnBuffers)
}

private def buildBuffers(): RDD[CachedBatch] = {
val output = cachedPlan.output
val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ case class DropTableCommand(

override def run(sparkSession: SparkSession): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
val isTempTable = catalog.isTemporaryTable(tableName)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename it to isTempView


if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) {
if (!isTempTable && catalog.tableExists(tableName)) {
// If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view
// issue an exception.
catalog.getTableMetadata(tableName).tableType match {
Expand All @@ -204,9 +205,10 @@ case class DropTableCommand(
}
}

if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) {
if (isTempTable || catalog.tableExists(tableName)) {
try {
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession.table(tableName), !isTempTable)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cascade = !isTempTable

} catch {
case NonFatal(e) => log.warn(e.toString, e)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ case class TruncateTableCommand(
spark.sessionState.refreshTable(tableName.unquotedString)
// Also try to drop the contents of the table from the columnar cache
try {
spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier))
spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier), true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

named argument. cascade = true

} catch {
case NonFatal(e) =>
log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession, viewDef, cascade = false, blocking = true)
sessionCatalog.dropTempView(viewName)
}
}
Expand All @@ -379,7 +380,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropGlobalTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession, viewDef, cascade = false, blocking = true)
sessionCatalog.dropGlobalTempView(viewName)
}
}
Expand Down Expand Up @@ -438,7 +440,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession.table(tableName), !sessionCatalog.isTemporaryTable(tableIdent))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val cascade = !sessionCatalog.isTemporaryTable(tableIdent)

...

}

/**
Expand Down Expand Up @@ -490,7 +494,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
// cached version and make the new version cached lazily.
if (isCached(table)) {
// Uncache the logicalPlan.
sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
sparkSession.sharedState.cacheManager.uncacheQuery(table, true, blocking = true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the same here.

// Cache it again.
sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
Expand Down Expand Up @@ -801,4 +800,67 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
}
assert(cachedData.collect === Seq(1001))
}

test("SPARK-24596 Non-cascading Cache Invalidation - uncache temporary view") {
withView("t1", "t2") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

withTempView

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.. good catch! A mistake caused by copy-paste.

sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1")
sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1")

assert(spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
sql("UNCACHE TABLE t1")
assert(!spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
}
}

test("SPARK-24596 Non-cascading Cache Invalidation - drop temporary view") {
withView("t1", "t2") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1")
sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1")

assert(spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
sql("DROP VIEW t1")
assert(spark.catalog.isCached("t2"))
}
}

test("SPARK-24596 Non-cascading Cache Invalidation - drop persistent view") {
withTable("t") {
spark.range(1, 10).toDF("key").withColumn("value", 'key * 2)
.write.format("json").saveAsTable("t")
withView("t1", "t2") {
sql("CREATE VIEW t1 AS SELECT * FROM t WHERE key > 1")

sql("CACHE TABLE t1")
sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1")

assert(spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
sql("DROP VIEW t1")
assert(!spark.catalog.isCached("t2"))
}
}
}

test("SPARK-24596 Non-cascading Cache Invalidation - uncache table") {
withTable("t") {
spark.range(1, 10).toDF("key").withColumn("value", 'key * 2)
.write.format("json").saveAsTable("t")
withView("t1", "t2") {
sql("CACHE TABLE t")
sql("CACHE TABLE t1 AS SELECT * FROM t WHERE key > 1")
sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1")

assert(spark.catalog.isCached("t"))
assert(spark.catalog.isCached("t1"))
assert(spark.catalog.isCached("t2"))
sql("UNCACHE TABLE t")
assert(!spark.catalog.isCached("t"))
assert(!spark.catalog.isCached("t1"))
assert(!spark.catalog.isCached("t2"))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,33 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
df.unpersist()
assert(df.storageLevel == StorageLevel.NONE)
}

test("SPARK-24596 Non-cascading Cache Invalidation") {
val expensiveUDF = udf({x: Int => Thread.sleep(10000); x})
val df = spark.range(0, 10).toDF("a")
val df1 = df.withColumn("b", expensiveUDF($"a"))
val df2 = df1.groupBy('a).agg(sum('b))
val df3 = df.agg(sum('a))

df1.cache()
df2.cache()
df2.collect()
df3.cache()

df1.unpersist(blocking = true)

assert(df1.storageLevel == StorageLevel.NONE)

val df4 = df1.groupBy('a).agg(sum('b)).select("sum(b)")
assertCached(df4)
// reuse loaded cache
failAfter(5 seconds) {
df4.collect()
}

val df5 = df.agg(sum('a)).filter($"sum(a)" > 1)
assertCached(df5)
// first time use, load cache
df5.collect()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how do we prove this takes more than 5 seconds?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We just need to prove the new InMemoryRelation works alright for building cache (since the plan has been re-compiled) ... maybe we should check result though. Plus, I deliberately made this dataframe not dependent on the UDF so it can finish quickly.

}
}