diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a7e5ab19bf84..87dae6b33159 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,14 +21,15 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{Queryable, QueryExecution}
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.types.StructType
 
 /**
  * :: Experimental ::
@@ -359,7 +360,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
+    new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan))
   }
 
   /**
@@ -368,11 +369,12 @@ class Dataset[T] private[sql](
    * that cast appropriately for the user facing interface.
    */
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
+    val withEncoders = columns.map(withEncoder)
+    val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
     val unresolvedPlan = Project(aliases, logicalPlan)
     val execution = new QueryExecution(sqlContext, unresolvedPlan)
     // Rebind the encoders to the nested schema that will be produced by the select.
-    val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
+    val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
       case (e: ExpressionEncoder[_], a) if !e.flat =>
         e.nested(a.toAttribute).resolve(execution.analyzed.output)
       case (e, a) =>
@@ -381,6 +383,16 @@ class Dataset[T] private[sql](
     new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
   }
 
+  private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = {
+    val e = c.expr transform {
+      case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
+        ta.copy(
+          aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]),
+          children = queryExecution.analyzed.output)
+    }
+    new TypedColumn(e, c.encoder)
+  }
+
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 206095a51976..6c89a5fb1d34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -114,4 +114,15 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
         ComplexResultAgg.toColumn),
       ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
   }
+
+  test("typed aggregation: in project list") {
+    val ds = Seq(1, 3, 2, 5).toDS()
+
+    checkAnswer(
+      ds.select(sum((i: Int) => i)),
+      11)
+    checkAnswer(
+      ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
+      11 -> 22)
+  }
 }
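
Note (illustrative, not part of the patch): the new `withEncoder` hook fills in a `TypedAggregateExpression`'s missing input encoder (`aEncoder`) and input attributes (`children`) from the Dataset it is selected from, which is what lets a typed `Aggregator` column run in a plain `select(...)` instead of only under `groupBy(...).agg(...)`. Below is a minimal sketch of the user-facing effect, assuming a 1.6-era session with `sqlContext` in scope; `IntSum` is a hypothetical Aggregator standing in for the suite's local `sum` helper:

    import org.apache.spark.sql.expressions.Aggregator
    import sqlContext.implicits._

    // Hypothetical Aggregator: folds Int inputs into a Long sum.
    object IntSum extends Aggregator[Int, Long, Long] {
      def zero: Long = 0L
      def reduce(b: Long, a: Int): Long = b + a
      def merge(b1: Long, b2: Long): Long = b1 + b2
      def finish(reduction: Long): Long = reduction
    }

    val ds = Seq(1, 3, 2, 5).toDS()
    // With this patch, a typed aggregate works directly in a projection:
    ds.select(IntSum.toColumn).collect()   // Array(11L)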