80 changes: 40 additions & 40 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -67,15 +67,21 @@ class Dataset[T] private[sql](
tEncoder: Encoder[T]) extends Queryable with Serializable {

/**
* An unresolved version of the internal encoder for the type of this dataset. This one is marked
* implicit so that we can use it when constructing new [[Dataset]] objects that have the same
* object type (that will be possibly resolved to a different schema).
* An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
* marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
* same object type (that will be possibly resolved to a different schema).
*/
private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)

/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)

/**
* The encoder where the expressions used to construct an object from an input row have been
* bound to the ordinals of the given schema.
Contributor: Nit: I'm going to change this to say this [[Dataset]]'s output schema

Member Author: I see. Thank you!

Contributor: oh, actually i forgot :(

Member Author: Let me add the change in a follow-up PR. : )

*/
private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)

private implicit def classTag = resolvedTEncoder.clsTag

@@ -89,7 +95,7 @@ class Dataset[T] private[sql](
override def schema: StructType = resolvedTEncoder.schema

/**
* Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format.
* Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
* @since 1.6.0
*/
override def printSchema(): Unit = toDF().printSchema()
@@ -111,7 +117,7 @@ class Dataset[T] private[sql](
* ************* */

/**
* Returns a new `Dataset` where each record has been mapped on to the specified type. The
* Returns a new [[Dataset]] where each record has been mapped on to the specified type. The
* method used to map columns depends on the type of `U`:
* - When `U` is a class, fields for the class will be mapped to columns of the same name
* (case sensitivity is determined by `spark.sql.caseSensitive`)
@@ -145,23 +151,20 @@ class Dataset[T] private[sql](
def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)

/**
* Returns this Dataset.
* Returns this [[Dataset]].
* @since 1.6.0
*/
// This is declared with parentheses to prevent the Scala compiler from treating
// `ds.toDS("1")` as invoking this toDS and then calling apply on the returned Dataset.
def toDS(): Dataset[T] = this

/**
* Converts this Dataset to an RDD.
* Converts this [[Dataset]] to an [[RDD]].
* @since 1.6.0
*/
def rdd: RDD[T] = {
val tEnc = resolvedTEncoder
val input = queryExecution.analyzed.output
queryExecution.toRdd.mapPartitions { iter =>
val bound = tEnc.bind(input)
iter.map(bound.fromRow)
iter.map(boundTEncoder.fromRow)
}
}
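
A minimal usage sketch for the `rdd` conversion this hunk simplifies (illustrative only, not part of the diff; it assumes a Spark 1.6 `SQLContext` named `sqlContext` with `sqlContext.implicits._` imported):

  import sqlContext.implicits._

  // Each InternalRow is decoded back into T by the bound encoder before the RDD sees it.
  val ds = Seq(1, 2, 3).toDS()
  val doubled = ds.rdd.map(_ * 2).collect()   // Array(2, 4, 6)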

@@ -189,15 +192,15 @@ class Dataset[T] private[sql](
def show(numRows: Int): Unit = show(numRows, truncate = true)

/**
* Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters
* Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
* will be truncated, and all cells will be aligned right.
*
* @since 1.6.0
*/
def show(): Unit = show(20)

/**
* Displays the top 20 rows of [[DataFrame]] in a tabular form.
* Displays the top 20 rows of [[Dataset]] in a tabular form.
*
* @param truncate Whether to truncate long strings. If true, strings more than 20 characters will
* be truncated and all cells will be aligned right
@@ -207,7 +210,7 @@ class Dataset[T] private[sql](
def show(truncate: Boolean): Unit = show(20, truncate)

/**
* Displays the [[DataFrame]] in a tabular form. For example:
* Displays the [[Dataset]] in a tabular form. For example:
* {{{
* year month AVG('Adj Close) MAX('Adj Close)
* 1980 12 0.503218 0.595103
@@ -291,7 +294,7 @@ class Dataset[T] private[sql](

/**
* (Scala-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
* @since 1.6.0
*/
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
@@ -307,7 +310,7 @@ class Dataset[T] private[sql](

/**
* (Java-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
* @since 1.6.0
*/
def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
@@ -341,28 +344,28 @@ class Dataset[T] private[sql](

/**
* (Scala-specific)
* Runs `func` on each element of this Dataset.
* Runs `func` on each element of this [[Dataset]].
* @since 1.6.0
*/
def foreach(func: T => Unit): Unit = rdd.foreach(func)

/**
* (Java-specific)
* Runs `func` on each element of this Dataset.
* Runs `func` on each element of this [[Dataset]].
* @since 1.6.0
*/
def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))

/**
* (Scala-specific)
* Runs `func` on each partition of this Dataset.
* Runs `func` on each partition of this [[Dataset]].
* @since 1.6.0
*/
def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)

/**
* (Java-specific)
* Runs `func` on each partition of this Dataset.
* Runs `func` on each partition of this [[Dataset]].
* @since 1.6.0
*/
def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
@@ -374,27 +377,27 @@ class Dataset[T] private[sql](

/**
* (Scala-specific)
* Reduces the elements of this Dataset using the specified binary function. The given function
* Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
def reduce(func: (T, T) => T): T = rdd.reduce(func)

/**
* (Java-specific)
* Reduces the elements of this Dataset using the specified binary function. The given function
* Reduces the elements of this Dataset using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
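
A quick illustration of the reduce contract described above (illustrative sketch, not part of the diff; `sqlContext.implicits._` assumed in scope):

  // The function must be commutative and associative; addition qualifies.
  val total = Seq(1, 2, 3, 4).toDS().reduce(_ + _)   // 10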

/**
* (Scala-specific)
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
* @since 1.6.0
*/
def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
val inputPlan = queryExecution.analyzed
val inputPlan = logicalPlan
val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
val executed = sqlContext.executePlan(withGroupingKey)

@@ -429,18 +432,18 @@ class Dataset[T] private[sql](

/**
* (Java-specific)
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
* @since 1.6.0
*/
def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
groupBy(f.call(_))(encoder)
def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
Contributor: A note: this is fine since RC1 failed, but we can't make these kinds of changes in the future as they break compatibility.

Member Author: Sure, next time, I will be careful. Thanks!

groupBy(func.call(_))(encoder)
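
A small sketch of the key-function `groupBy` family touched by this hunk (Scala variant shown; illustrative only, with `sqlContext.implicits._` assumed in scope to supply the key encoder):

  val words = Seq("a", "bb", "ccc", "dd").toDS()
  // Group by a computed key; the result is a GroupedDataset[Int, String].
  val byLength = words.groupBy(_.length)
  // For example, byLength.count() yields a Dataset[(Int, Long)] of (key, count) pairs.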

/* ****************** *
* Typed Relational *
* ****************** */

/**
* Selects a set of column based expressions.
* Returns a new [[DataFrame]] by selecting a set of column based expressions.
* {{{
* df.select($"colA", $"colB" + 1)
* }}}
@@ -464,8 +467,8 @@ class Dataset[T] private[sql](
sqlContext,
Project(
c1.withInputType(
resolvedTEncoder.bind(queryExecution.analyzed.output),
queryExecution.analyzed.output).named :: Nil,
boundTEncoder,
logicalPlan.output).named :: Nil,
logicalPlan))
}
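
For context, a sketch of the typed single-column `select` whose encoder binding this hunk reuses (illustrative only; it assumes the flat Int encoder names its single column `value` and that `org.apache.spark.sql.functions.expr` is imported):

  import org.apache.spark.sql.functions.expr

  val nums = Seq(1, 2, 3).toDS()
  // expr(...).as[Int] builds a TypedColumn, so select returns a Dataset[Int] rather than a DataFrame.
  val plusOne = nums.select(expr("value + 1").as[Int])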

@@ -477,7 +480,7 @@ class Dataset[T] private[sql](
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))

new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
@@ -654,25 +657,22 @@ class Dataset[T] private[sql](
* Returns an array that contains all the elements in this [[Dataset]].
*
* Running collect requires moving all the data into the application's driver process, and
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
* doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
*/
def collect(): Array[T] = {
// This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
// to convert the rows into objects of type T.
val tEnc = resolvedTEncoder
val input = queryExecution.analyzed.output
val bound = tEnc.bind(input)
queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
}
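
For context on the OutOfMemoryError warning above, a sketch of what collect does on the driver (illustrative only, not part of the diff):

  // Every element is decoded via boundTEncoder and copied into driver memory; keep the Dataset small.
  val all: Array[String] = Seq("x", "y").toDS().collect()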

/**
* Returns an array that contains all the elements in this [[Dataset]].
*
* Running collect requires moving all the data into the application's driver process, and
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
* doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
@@ -683,7 +683,7 @@ class Dataset[T] private[sql](
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
* a very large `num` can crash the driver process with OutOfMemoryError.
* @since 1.6.0
*/
def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
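
And a sketch of `take`, which limits the plan before collecting (illustrative only; `sqlContext.implicits._` assumed in scope):

  // Only the first `num` elements are shipped to the driver.
  val firstTwo = Seq(1, 2, 3, 4).toDS().take(2)   // Array(1, 2)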
@@ -692,7 +692,7 @@ class Dataset[T] private[sql](
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
* a very large `num` can crash the driver process with OutOfMemoryError.
* @since 1.6.0
*/
def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)