[SPARK-11905] [SQL] Support Persist/Cache and Unpersist in Dataset APIs #9889
Dataset.scala:

```diff
@@ -17,6 +17,8 @@
 package org.apache.spark.sql

+import org.apache.spark.storage.StorageLevel
+
 import scala.collection.JavaConverters._

 import org.apache.spark.annotation.Experimental
```
```diff
@@ -461,7 +463,7 @@ class Dataset[T] private[sql](
    * combined.
    *
    * Note that, this function is not a typical set union operation, in that it does not eliminate
-   * duplicate items. As such, it is analagous to `UNION ALL` in SQL.
+   * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
    * @since 1.6.0
    */
  def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
```
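As a side note, a small usage sketch of the behaviour that doc comment describes (illustrative only, assuming a `SQLContext` named `sqlContext` with its implicits imported; not part of this diff):

```scala
import sqlContext.implicits._

val left = Seq(1, 2, 2).toDS()
val right = Seq(2, 3).toDS()

// Like SQL UNION ALL, duplicates are preserved.
left.union(right).collect()   // Array(1, 2, 2, 2, 3)
```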
```diff
@@ -510,7 +512,6 @@ class Dataset[T] private[sql](
       case _ => Alias(CreateStruct(rightOutput), "_2")()
     }
-
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
     withPlan[(T, U)](other) { (left, right) =>
```
```diff
@@ -579,11 +580,50 @@ class Dataset[T] private[sql](
    */
  def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)

+  /* ******* *
+   *  Cache  *
+   * ******* */
+
+  /**
+   * @since 1.6.0
+   */
```
Contributor: The comment style here is off and we should actually have a description. Could we just move the functions/docs from DataFrame to Queryable?

Member (Author): So far, we are unable to move the functions to `Queryable`.

Contributor: @marmbrus moving functions into Queryable actually breaks both scaladoc and javadoc.

Contributor: @rxin I think that's only because we explicitly exclude execution from scaladoc. Maybe we should move Queryable? Or don't exclude that class. I don't want to duplicate a ton of docs.
```diff
+  def persist(): this.type = {
+    sqlContext.cacheManager.cacheQuery(this)
+    this
+  }
+
+  /**
+   * @since 1.6.0
+   */
+  def cache(): this.type = persist()
+
+  /**
+   * @since 1.6.0
+   */
+  def persist(newLevel: StorageLevel): this.type = {
+    sqlContext.cacheManager.cacheQuery(this, None, newLevel)
+    this
+  }
+
+  /**
+   * @since 1.6.0
+   */
+  def unpersist(blocking: Boolean): this.type = {
+    sqlContext.cacheManager.tryUncacheQuery(this, blocking)
+    this
+  }
+
+  /**
+   * @since 1.6.0
+   */
+  def unpersist(): this.type = unpersist(blocking = false)
+
   /* ******************** *
    *  Internal Functions  *
    * ******************** */

-  private[sql] def logicalPlan = queryExecution.analyzed
+  private[sql] def logicalPlan : LogicalPlan = queryExecution.analyzed
```
Contributor: No space before `:`.
```diff

   private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
     new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
```
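Taken together, these new methods give Dataset the same caching surface as DataFrame. A minimal usage sketch (illustrative only, assuming a `SQLContext` named `sqlContext` with its implicits imported; not part of this diff):

```scala
import org.apache.spark.storage.StorageLevel
import sqlContext.implicits._

val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()

ds.cache()                            // alias for persist(); default level is MEMORY_AND_DISK
ds.count()                            // first action materializes the cached columnar data
ds.unpersist(blocking = true)         // synchronously drop the cached data

ds.persist(StorageLevel.MEMORY_ONLY)  // persist with an explicit storage level
ds.unpersist()                        // non-blocking removal by default
```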
CacheManager.scala:

```diff
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import java.util.concurrent.locks.ReentrantReadWriteLock

 import org.apache.spark.Logging
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.storage.StorageLevel
```
```diff
@@ -75,12 +75,12 @@ private[sql] class CacheManager extends Logging {
   }

   /**
-   * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike
-   * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing
-   * the in-memory columnar representation of the underlying table is expensive.
+   * Caches the data produced by the logical representation of the given [[DataFrame]]/[[Dataset]].
```
Contributor: This is a nit, but you could probably just say `Queryable`.
```diff
+   * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
+   * recomputing the in-memory columnar representation of the underlying table is expensive.
    */
   private[sql] def cacheQuery(
-      query: DataFrame,
+      query: Queryable,
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
```
```diff
@@ -100,7 +100,7 @@ private[sql] class CacheManager extends Logging {
     }
   }

-  /** Removes the data for the given [[DataFrame]] from the cache */
+  /** Removes the data for the given [[DataFrame]]/[[Dataset]] from the cache */
   private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
```
```diff
@@ -109,9 +109,11 @@ private[sql] class CacheManager extends Logging {
     cachedData.remove(dataIndex)
   }

-  /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */
+  /** Tries to remove the data for the given [[DataFrame]]/[[Dataset]] from the cache
+    * if it's cached
+    */
   private[sql] def tryUncacheQuery(
-      query: DataFrame,
+      query: Queryable,
       blocking: Boolean = true): Boolean = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
```
```diff
@@ -123,8 +125,8 @@ private[sql] class CacheManager extends Logging {
     found
   }

-  /** Optionally returns cached data for the given [[DataFrame]] */
-  private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock {
+  /** Optionally returns cached data for the given [[DataFrame]]/[[Dataset]] */
+  private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock {
     lookupCachedData(query.queryExecution.analyzed)
   }
```
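The hunks above all follow the same pattern: the cache is keyed by the analyzed logical plan, so any `Queryable` (DataFrame or Dataset) whose plan is semantically equal to a cached plan resolves to the same entry. A simplified, stand-alone sketch of that lookup, using illustrative names (`CachedEntry`, `lookupCached`) rather than Spark's own:

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Illustrative stand-in for CacheManager's internal CachedData entries.
final case class CachedEntry(plan: LogicalPlan, cachedRepresentation: AnyRef)

def lookupCached(cached: Seq[CachedEntry], planToLookUp: LogicalPlan): Option[CachedEntry] =
  // sameResult compares plans semantically (e.g. ignoring expression ids), so
  // equivalent DataFrame and Dataset queries resolve to the same cache entry.
  cached.find(entry => planToLookUp.sameResult(entry.plan))
```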
Queryable.scala:

```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.types.StructType

 import scala.util.control.NonFatal
@@ -27,6 +28,7 @@ private[sql] trait Queryable {
   def schema: StructType
   def queryExecution: QueryExecution
   def sqlContext: SQLContext
+  private[sql] def logicalPlan: LogicalPlan

   override def toString: String = {
     try {
```
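For reference, the alternative the reviewers discuss above (hoisting the caching methods from DataFrame/Dataset into Queryable so the docs live in one place) might look roughly like the following. This is a hypothetical sketch, not part of this PR; the method bodies simply mirror the Dataset implementations shown earlier.

```scala
package org.apache.spark.sql.execution

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.StructType
import org.apache.spark.storage.StorageLevel

private[sql] trait Queryable {
  def schema: StructType
  def queryExecution: QueryExecution
  def sqlContext: SQLContext
  private[sql] def logicalPlan: LogicalPlan

  // Shared caching API (sketch): both DataFrame and Dataset would inherit
  // these, so the scaladoc would not need to be duplicated.
  def persist(): this.type = {
    sqlContext.cacheManager.cacheQuery(this)
    this
  }

  def cache(): this.type = persist()

  def persist(newLevel: StorageLevel): this.type = {
    sqlContext.cacheManager.cacheQuery(this, None, newLevel)
    this
  }

  def unpersist(blocking: Boolean): this.type = {
    sqlContext.cacheManager.tryUncacheQuery(this, blocking)
    this
  }

  def unpersist(): this.type = unpersist(blocking = false)
}
```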
DatasetSuite.scala:

```diff
@@ -19,6 +19,8 @@ package org.apache.spark.sql

 import java.io.{ObjectInput, ObjectOutput, Externalizable}

+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+
 import scala.language.postfixOps

 import org.apache.spark.sql.functions._
@@ -213,6 +215,56 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }

+  test("persist and unpersist") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
+    val cached = ds.cache()
+    // count triggers the caching action. It should not throw.
+    cached.count()
+    // Make sure, the Dataset is indeed cached.
+    assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty)
+    assertResult(1, "InMemoryRelation not found, testData should have been cached") {
+      cached.queryExecution.withCachedData.collect {
+        case r: InMemoryRelation => r
+      }.size
+    }
+
+    // Check result.
+    checkAnswer(
+      cached,
+      2, 3, 4)
+    // Drop the cache.
+    cached.unpersist()
+  }
+
+  test("persist and then rebind right encoder when join 2 datasets") {
+    val ds1 = Seq("1", "2").toDS().as("a")
+    val ds2 = Seq(2, 3).toDS().as("b")
+
+    ds1.persist()
+    ds2.persist()
+
+    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
+    checkAnswer(joined, ("2", 2))
+
+    ds1.unpersist()
+    ds2.unpersist()
+  }
+
+  test("persist and then groupBy columns asKey, map") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    ds.persist()
+
+    val grouped = ds.groupBy($"_1").keyAs[String]
+    val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
+    agged.persist()
+
+    checkAnswer(
+      agged.filter(_._1 == "b"),
+      ("b", 3))
+
+    ds.unpersist()
+    agged.unpersist()
+  }
+
   test("groupBy function, keys") {
     val ds = Seq(("a", 1), ("b", 1)).toDS()
     val grouped = ds.groupBy(v => (1, v._2))
```
Review comment: order imports.