Changes from 4 commits

25 commits
01e4cdf
Merge remote-tracking branch 'upstream/master'
gatorsmile Nov 13, 2015
6835704
Merge remote-tracking branch 'upstream/master'
gatorsmile Nov 14, 2015
9180687
Merge remote-tracking branch 'upstream/master'
gatorsmile Nov 14, 2015
b38a21e
SPARK-11633
gatorsmile Nov 17, 2015
d2b84af
Merge remote-tracking branch 'upstream/master' into joinMakeCopy
gatorsmile Nov 17, 2015
fda8025
Merge remote-tracking branch 'upstream/master'
gatorspark Nov 17, 2015
ac0dccd
Merge branch 'master' of https://github.com/gatorsmile/spark
gatorspark Nov 17, 2015
6e0018b
Merge remote-tracking branch 'upstream/master'
Nov 20, 2015
0546772
converge
gatorsmile Nov 20, 2015
b37a64f
converge
gatorsmile Nov 20, 2015
f061671
Support Persist/Cache and Unpersist in DataSet APIs
gatorsmile Nov 22, 2015
88d5e9d
Merge remote-tracking branch 'upstream/master' into top
gatorsmile Nov 22, 2015
c135e1f
update the @since
gatorsmile Nov 22, 2015
661260b
Merge remote-tracking branch 'upstream/master'
gatorsmile Nov 23, 2015
2517777
adding more test cases
gatorsmile Nov 23, 2015
aa5dc52
Merge remote-tracking branch 'upstream/master' into top
gatorsmile Nov 25, 2015
2dfa0fd
Merge remote-tracking branch 'upstream/master'
gatorsmile Nov 25, 2015
c4489ed
Merge remote-tracking branch 'upstream/master' into top
gatorsmile Nov 25, 2015
683fa6f
resolved all the comments
gatorsmile Nov 25, 2015
1c82396
Merge remote-tracking branch 'upstream/master' into top
gatorsmile Nov 25, 2015
d929d9b
Merge remote-tracking branch 'upstream/master'
gatorsmile Nov 25, 2015
92ede39
Merge branch 'top' into persistDSmerge
gatorsmile Nov 25, 2015
8071d30
Merge remote-tracking branch 'upstream/master' into persistDSmerge
gatorsmile Dec 1, 2015
b9518ee
updated the code based on the review comments from Michael Armbrust.
gatorsmile Dec 1, 2015
b8d287a
Changed the name from CacheSuite.scala to DatasetCacheSuite.scala
gatorsmile Dec 1, 2015
46 changes: 43 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql

import org.apache.spark.storage.StorageLevel
Contributor: order imports.


import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
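
For reference, a minimal sketch of the grouping the reviewer is asking for, with the new StorageLevel import placed alongside the other org.apache.spark imports rather than ahead of the Scala ones; the surrounding imports are taken from the hunk above:

```scala
// Scala/Java imports first, then the org.apache.spark group, alphabetized.
import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.storage.StorageLevel
```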
@@ -461,7 +463,7 @@ class Dataset[T] private[sql](
* combined.
*
* Note that, this function is not a typical set union operation, in that it does not eliminate
* duplicate items. As such, it is analagous to `UNION ALL` in SQL.
* duplicate items. As such, it is analogous to `UNION ALL` in SQL.
* @since 1.6.0
*/
def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
@@ -510,7 +512,6 @@ class Dataset[T] private[sql](
case _ => Alias(CreateStruct(rightOutput), "_2")()
}


implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
withPlan[(T, U)](other) { (left, right) =>
@@ -579,11 +580,50 @@ class Dataset[T] private[sql](
*/
def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)


/* ******* *
* Cache *
* ******* */

/**
* @since 1.6.0
*/
Contributor: The comment style here is off, and we should actually have a description. Could we just move the functions/docs from DataFrame to Queryable?

Member Author: So far, we are unable to move the functions to Queryable because the types of the returned values are different. I just added the descriptions in both DataFrame and Dataset. Hopefully, that resolves your concern. Thanks!

Contributor: @marmbrus moving functions into Queryable actually breaks both scaladoc and javadoc.

Contributor: @rxin I think that's only because we explicitly exclude execution from scaladoc. Maybe we should move Queryable? Or don't exclude that class. I don't want to duplicate a ton of docs.

def persist(): this.type = {
sqlContext.cacheManager.cacheQuery(this)
this
}

/**
* @since 1.6.0
*/
def cache(): this.type = persist()

/**
* @since 1.6.0
*/
def persist(newLevel: StorageLevel): this.type = {
sqlContext.cacheManager.cacheQuery(this, None, newLevel)
this
}

/**
* @since 1.6.0
*/
def unpersist(blocking: Boolean): this.type = {
sqlContext.cacheManager.tryUncacheQuery(this, blocking)
this
}

/**
* @since 1.6.0
*/
def unpersist(): this.type = unpersist(blocking = false)

/* ******************** *
* Internal Functions *
* ******************** */

private[sql] def logicalPlan = queryExecution.analyzed
private[sql] def logicalPlan : LogicalPlan = queryExecution.analyzed
Contributor: No space before :.


private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
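
For orientation, a minimal usage sketch of the API added in the hunk above. The sample data, and the assumption that a SQLContext named `sqlContext` (with its implicits) is in scope, e.g. in spark-shell, are not part of the diff:

```scala
import org.apache.spark.storage.StorageLevel
import sqlContext.implicits._

val ds = Seq(1, 2, 3).toDS()

// cache() is persist() with the default storage level, which is MEMORY_AND_DISK
// (unlike RDD.cache(), which defaults to MEMORY_ONLY).
ds.cache()
ds.count()        // the first action materializes the in-memory columnar cache
ds.unpersist()    // non-blocking; unpersist(blocking = true) waits for removal

// An explicit storage level can also be supplied.
val ds2 = Seq("a", "b").toDS().persist(StorageLevel.MEMORY_ONLY)
ds2.count()
ds2.unpersist()
```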
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import java.util.concurrent.locks.ReentrantReadWriteLock

import org.apache.spark.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.storage.StorageLevel
@@ -75,12 +75,12 @@ private[sql] class CacheManager extends Logging {
}

/**
* Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike
* `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing
* the in-memory columnar representation of the underlying table is expensive.
* Caches the data produced by the logical representation of the given [[DataFrame]]/[[Dataset]].
Contributor: This is a nit, but you could probably just say [[Queryable]].

* Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
* recomputing the in-memory columnar representation of the underlying table is expensive.
*/
private[sql] def cacheQuery(
query: DataFrame,
query: Queryable,
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
@@ -100,7 +100,7 @@
}
}

/** Removes the data for the given [[DataFrame]] from the cache */
/** Removes the data for the given [[DataFrame]]/[[Dataset]] from the cache */
private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -109,9 +109,11 @@
cachedData.remove(dataIndex)
}

/** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */
/** Tries to remove the data for the given [[DataFrame]]/[[Dataset]] from the cache
* if it's cached
*/
private[sql] def tryUncacheQuery(
query: DataFrame,
query: Queryable,
blocking: Boolean = true): Boolean = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -123,8 +125,8 @@
found
}

/** Optionally returns cached data for the given [[DataFrame]] */
private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock {
/** Optionally returns cached data for the given [[DataFrame]]/[[Dataset]] */
private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock {
lookupCachedData(query.queryExecution.analyzed)
}

@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.StructType

import scala.util.control.NonFatal
@@ -27,6 +28,7 @@ private[sql] trait Queryable {
def schema: StructType
def queryExecution: QueryExecution
def sqlContext: SQLContext
private[sql] def logicalPlan: LogicalPlan
Contributor: Can we just get this from the queryExecution in the cache manager? Or at least define it explicitly here. I don't want DataFrames and Datasets to fall out of sync with regard to what the canonical plan phase is.
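
A minimal sketch of the "define it explicitly here" option the comment mentions, i.e. giving the trait a concrete default so DataFrame and Dataset cannot drift on the canonical plan phase. This is hypothetical; the diff leaves the member abstract:

```scala
// Hypothetical alternative, inside org.apache.spark.sql.execution
// (where QueryExecution lives): a concrete default pins the canonical
// plan phase in one place for both DataFrame and Dataset.
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

private[sql] trait Queryable {
  def queryExecution: QueryExecution

  private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed
}
```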


override def toString: String = {
try {
52 changes: 52 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql

import java.io.{ObjectInput, ObjectOutput, Externalizable}

import org.apache.spark.sql.execution.columnar.InMemoryRelation

import scala.language.postfixOps

import org.apache.spark.sql.functions._
@@ -213,6 +215,56 @@ class DatasetSuite extends QueryTest with SharedSQLContext {

}

test("persist and unpersist") {
Contributor: I'd consider putting these in their own suite. Caching is a pretty isolated concern. I would also like to see some more tests for operations that aren't expressed as logical operations (i.e., map/filter with lambda functions instead of expressions).

val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
val cached = ds.cache()
// count triggers the caching action. It should not throw.
cached.count()
// Make sure, the Dataset is indeed cached.
assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty)
assertResult(1, "InMemoryRelation not found, testData should have been cached") {
cached.queryExecution.withCachedData.collect {
case r: InMemoryRelation => r
}.size
}
Contributor: Let's just generalize and use assertCached, which should already be in scope.
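
One way to act on this, sketched as a generalization of the existing QueryTest.assertCached helper to the Queryable trait introduced in this PR; this is hypothetical and the exact existing signature may differ slightly:

```scala
// In QueryTest: accept any Queryable (DataFrame or Dataset) instead of only DataFrame.
def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = {
  val planWithCaching = query.queryExecution.withCachedData
  val cachedData = planWithCaching.collect {
    case cached: InMemoryRelation => cached
  }
  assert(
    cachedData.size == numCachedTables,
    s"Expected $numCachedTables cached relation(s), found ${cachedData.size}:\n$planWithCaching")
}
```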

// Check result.
checkAnswer(
cached,
2, 3, 4)
// Drop the cache.
cached.unpersist()
}
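
Following up on the reviewer's note above about operations that are not expressed as logical operations, a sketch of a lambda-based variant in the same style as the test above; the test name and data are hypothetical:

```scala
test("persist with map/filter lambdas") {
  val ds = Seq(1, 2, 3).toDS().map(_ + 1).filter(_ > 2)
  val cached = ds.cache()
  // count triggers the caching action and runs the lambda pipeline.
  cached.count()
  // Make sure the Dataset is indeed cached.
  assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty)
  checkAnswer(cached, 3, 4)
  cached.unpersist()
}
```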

test("persist and then rebind right encoder when join 2 datasets") {
val ds1 = Seq("1", "2").toDS().as("a")
val ds2 = Seq(2, 3).toDS().as("b")

ds1.persist()
ds2.persist()

val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
checkAnswer(joined, ("2", 2))
Contributor: We should assertCached in every test that checks caching.


ds1.unpersist()
ds2.unpersist()
}

test("persist and then groupBy columns asKey, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
ds.persist()

val grouped = ds.groupBy($"_1").keyAs[String]
val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
agged.persist()
Contributor: It's a little odd to have a test that persists more than one point in the lineage, unless you are explicitly testing that only the latest possible set of materialized data is being used.


checkAnswer(
agged.filter(_._1 == "b"),
("b", 3))

ds.unpersist()
agged.unpersist()
}

test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
val grouped = ds.groupBy(v => (1, v._2))