diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b3e4688557ba..b8b6adf41f85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -162,6 +162,10 @@ case class Limit(limit: Int, child: SparkPlan)
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
 
+  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = true
+
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -200,6 +204,15 @@ case class TakeOrderedAndProject(
     projectOutput.getOrElse(child.output)
   }
 
+  override def outputsUnsafeRows: Boolean = if (projectList.isDefined) {
+    true
+  } else {
+    child.outputsUnsafeRows
+  }
+
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = true
+
   override def outputPartitioning: Partitioning = SinglePartition
 
   // We need to use an interpreted ordering here because generated orderings cannot be serialized
@@ -207,11 +220,15 @@ case class TakeOrderedAndProject(
   private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
 
   // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
-  @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
+  @transient private val projection = projectList.map(UnsafeProjection.create(_, child.output))
 
   private def collectData(): Array[InternalRow] = {
     val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
-    projection.map(data.map(_)).getOrElse(data)
+    if (projection.isDefined) {
+      projection.map(p => data.map(p(_).copy().asInstanceOf[InternalRow])).get
+    } else {
+      data
+    }
   }
 
   override def executeCollect(): Array[InternalRow] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index 2328899bb2f8..27192d6f2d6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -26,6 +26,15 @@ import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{ArrayType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
+case class DummySafeNode(limit: Int, child: SparkPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+  override def canProcessUnsafeRows: Boolean = false
+  override def canProcessSafeRows: Boolean = true
+
+  override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
+  protected override def doExecute(): RDD[InternalRow] = child.execute()
+}
+
 class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
 
   private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
@@ -39,7 +48,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
   assert(outputsUnsafe.outputsUnsafeRows)
 
   test("planner should insert unsafe->safe conversions when required") {
-    val plan = Limit(10, outputsUnsafe)
+    val plan = DummySafeNode(10, outputsUnsafe)
     val preparedPlan = sqlContext.prepareForExecution.execute(plan)
     assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
   }