From ca4326ab0c4d26af3eddcc9550624db91d195e8a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 13 Dec 2015 23:46:13 +0800 Subject: [PATCH 1/6] Support UnsafeRow in LocalTableScan. --- .../sql/catalyst/optimizer/Optimizer.scala | 4 ++-- .../plans/logical/LocalRelation.scala | 22 ++++++++++++------ .../ConvertToLocalRelationSuite.scala | 4 ++-- .../org/apache/spark/sql/SQLContext.scala | 8 +++++-- .../spark/sql/execution/LocalTableScan.scala | 15 +++++++++--- .../sql/execution/stat/StatFunctions.scala | 3 ++- .../execution/RowFormatConvertersSuite.scala | 5 ++-- .../spark/sql/execution/local/DummyNode.scala | 23 ++++++++++++------- .../sql/execution/local/ExpandNodeSuite.scala | 2 +- .../sql/execution/local/FilterNodeSuite.scala | 2 +- .../execution/local/HashJoinNodeSuite.scala | 4 ++-- .../execution/local/IntersectNodeSuite.scala | 4 ++-- .../sql/execution/local/LimitNodeSuite.scala | 2 +- .../sql/execution/local/LocalNodeSuite.scala | 6 ++--- .../local/NestedLoopJoinNodeSuite.scala | 4 ++-- .../execution/local/ProjectNodeSuite.scala | 2 +- .../sql/execution/local/SampleNodeSuite.scala | 2 +- .../TakeOrderedAndProjectNodeSuite.scala | 2 +- .../sql/execution/local/UnionNodeSuite.scala | 2 +- 19 files changed, 72 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f6088695a927..ed326b635ff7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -937,8 +937,8 @@ object DecimalAggregates extends Rule[LogicalPlan] { object ConvertToLocalRelation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Project(projectList, LocalRelation(output, data)) => - val projection = new InterpretedProjection(projectList, output) - LocalRelation(projectList.map(_.toAttribute), data.map(projection)) + val projection = UnsafeProjection.create(projectList, output) + LocalRelation(projectList.map(_.toAttribute), data.map(projection(_).asInstanceOf[UnsafeRow])) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index e3e7a11dba97..5e56ce10083c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} import org.apache.spark.sql.types.{StructField, StructType} @@ -29,20 +30,27 @@ object LocalRelation { new LocalRelation(StructType(output1 +: output).toAttributes) } + def fromInternalRows(output: Seq[Attribute], data: Seq[InternalRow]): LocalRelation = { + val projection = UnsafeProjection.create(output.map(_.dataType).toArray) + new LocalRelation(output, data.map(projection(_).copy())) + } + def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): 
LocalRelation = { val schema = StructType.fromAttributes(output) - val converter = CatalystTypeConverters.createToCatalystConverter(schema) - LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) + val encoder = RowEncoder(schema) + LocalRelation(output, data.map(encoder.toRow(_).copy().asInstanceOf[UnsafeRow])) } - def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { + def fromProduct[T <: Product : ExpressionEncoder]( + output: Seq[Attribute], + data: Seq[T]): LocalRelation = { + val encoder = implicitly[ExpressionEncoder[T]] val schema = StructType.fromAttributes(output) - val converter = CatalystTypeConverters.createToCatalystConverter(schema) - LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) + new LocalRelation(output, data.map(encoder.toRow(_).copy().asInstanceOf[UnsafeRow])) } } -case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) +case class LocalRelation(output: Seq[Attribute], data: Seq[UnsafeRow] = Nil) extends LeafNode with analysis.MultiInstanceRelation { /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index 049a19b86f7c..92fe9261cca5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -35,11 +35,11 @@ class ConvertToLocalRelationSuite extends PlanTest { } test("Project on LocalRelation should be turned into a single LocalRelation") { - val testRelation = LocalRelation( + val testRelation = LocalRelation.fromInternalRows( LocalRelation('a.int, 'b.int).output, InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) - val correctAnswer = LocalRelation( + val correctAnswer = LocalRelation.fromInternalRows( LocalRelation('a1.int, 'b1.int).output, InternalRow(1, 3) :: InternalRow(4, 6) :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index db286ea8700b..02f244fd9d22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} @@ -426,6 +427,7 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes @@ -501,7 +503,7 @@ class SQLContext private[sql]( def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes - val encoded = data.map(d => enc.toRow(d).copy()) + val encoded = data.map(d => enc.toRow(d).copy().asInstanceOf[UnsafeRow]) val 
plan = new LocalRelation(attributes, encoded) new Dataset[T](this, plan) @@ -604,7 +606,9 @@ class SQLContext private[sql]( val className = beanClass.getName val beanInfo = Introspector.getBeanInfo(beanClass) val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) - DataFrame(self, LocalRelation(attrSeq, rows.toSeq)) + val projection = UnsafeProjection.create(attrSeq) + DataFrame(self, + LocalRelation(attrSeq, rows.toSeq.map(projection(_).copy().asInstanceOf[UnsafeRow]))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index ba7f6287ac6c..c62b0c9b9631 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,17 +19,26 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, UnsafeProjection, UnsafeRow} +private[sql] object LocalTableScan { + def fromInternalRows(output: Seq[Attribute], data: Seq[InternalRow]): LocalTableScan = { + val projection = UnsafeProjection.create(output.map(_.dataType).toArray) + new LocalTableScan(output, data.map(projection(_).copy())) + } +} /** * Physical plan node for scanning data from a local collection. */ private[sql] case class LocalTableScan( output: Seq[Attribute], - rows: Seq[InternalRow]) extends LeafNode { + rows: Seq[UnsafeRow]) extends LeafNode { + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) + private lazy val rdd = sqlContext.sparkContext.parallelize(rows).asInstanceOf[RDD[InternalRow]] protected override def doExecute(): RDD[InternalRow] = rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 00231d65a7d5..6269bc7054c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -144,6 +144,7 @@ private[sql] object StatFunctions extends Logging { } val schema = StructType(StructField(tableName, StringType) +: headerNames) - new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) + new DataFrame(df.sqlContext, + LocalRelation.fromInternalRows(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 13d68a103a22..85632e98fbc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -99,12 +99,11 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) } - val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) + val relation = LocalTableScan.fromInternalRows(Seq(AttributeReference("t", schema)()), rows) val plan = DummyPlan( - 
ConvertToSafe( - ConvertToUnsafe(relation))) + ConvertToSafe(relation)) assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala index efc3227dd60d..a8edb1991dd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -17,11 +17,26 @@ package org.apache.spark.sql.execution.local +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +private[local] object DummyNode { + val CLOSED: Int = Int.MinValue + + def apply[A <: Product : TypeTag]( + output: Seq[Attribute], + data: Seq[A], + conf: SQLConf = new SQLConf): DummyNode = { + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + new DummyNode(output, LocalRelation.fromProduct(output, data), conf) + } +} + /** * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. */ @@ -36,10 +51,6 @@ private[local] case class DummyNode( private var index: Int = CLOSED private val input: Seq[InternalRow] = relation.data - def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { - this(output, LocalRelation.fromProduct(output, data), conf) - } - def isOpen: Boolean = index != CLOSED override def children: Seq[LocalNode] = Seq.empty @@ -62,7 +73,3 @@ private[local] case class DummyNode( index = CLOSED } } - -private object DummyNode { - val CLOSED: Int = Int.MinValue -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala index bbd94d8da2d1..286c5509270c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ class ExpandNodeSuite extends LocalNodeTest { private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { - val inputNode = new DummyNode(kvIntAttributes, inputData) + val inputNode = DummyNode(kvIntAttributes, inputData) val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) val resolvedNode = resolveExpressions(expandNode) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index 4eadce646d37..083e8f1122b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -24,7 +24,7 @@ class FilterNodeSuite extends LocalNodeTest { private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { val cond = 'k % 2 === 0 - val inputNode = new DummyNode(kvIntAttributes, inputData) + val inputNode = DummyNode(kvIntAttributes, inputData) val filterNode = new FilterNode(conf, cond, inputNode) val resolvedNode 
= resolveExpressions(filterNode) val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index c30327185e16..959fc28affc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -62,8 +62,8 @@ class HashJoinNodeSuite extends LocalNodeTest { // Actual test body def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { val rightInputMap = rightInput.toMap - val leftNode = new DummyNode(joinNameAttributes, leftInput) - val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val leftNode = DummyNode(joinNameAttributes, leftInput) + val rightNode = DummyNode(joinNicknameAttributes, rightInput) val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => { val binaryHashJoinNode = BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala index c0ad2021b204..2909465dab75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -24,8 +24,8 @@ class IntersectNodeSuite extends LocalNodeTest { val n = 100 val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray - val leftNode = new DummyNode(kvIntAttributes, leftData) - val rightNode = new DummyNode(kvIntAttributes, rightData) + val leftNode = DummyNode(kvIntAttributes, leftData) + val rightNode = DummyNode(kvIntAttributes, rightData) val intersectNode = new IntersectNode(conf, leftNode, rightNode) val expectedOutput = leftData.intersect(rightData) val actualOutput = intersectNode.collect().map { case row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index fb790636a368..76e92541ad51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -21,7 +21,7 @@ package org.apache.spark.sql.execution.local class LimitNodeSuite extends LocalNodeTest { private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { - val inputNode = new DummyNode(kvIntAttributes, inputData) + val inputNode = DummyNode(kvIntAttributes, inputData) val limitNode = new LimitNode(conf, limit, inputNode) val expectedOutput = inputData.take(limit) val actualOutput = limitNode.collect().map { case row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala index 0d1ed99eec6c..eadfbd4d591f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -22,7 +22,7 @@ class LocalNodeSuite extends LocalNodeTest { private val data = (1 to 100).map { 
i => (i, i) }.toArray test("basic open, next, fetch, close") { - val node = new DummyNode(kvIntAttributes, data) + val node = DummyNode(kvIntAttributes, data) assert(!node.isOpen) node.open() assert(node.isOpen) @@ -42,7 +42,7 @@ class LocalNodeSuite extends LocalNodeTest { } test("asIterator") { - val node = new DummyNode(kvIntAttributes, data) + val node = DummyNode(kvIntAttributes, data) val iter = node.asIterator node.open() data.foreach { case (k, v) => @@ -61,7 +61,7 @@ class LocalNodeSuite extends LocalNodeTest { } test("collect") { - val node = new DummyNode(kvIntAttributes, data) + val node = DummyNode(kvIntAttributes, data) node.open() val collected = node.collect() assert(collected.size === data.size) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index 45df2ea6552d..f42127f7a435 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -47,8 +47,8 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest { joinType: JoinType, leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { - val leftNode = new DummyNode(joinNameAttributes, leftInput) - val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val leftNode = DummyNode(joinNameAttributes, leftInput) + val rightNode = DummyNode(joinNicknameAttributes, rightInput) val cond = 'id1 === 'id2 val makeNode = (node1: LocalNode, node2: LocalNode) => { resolveExpressions( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index 02ecb23d34b2..fdd8916ad577 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -28,7 +28,7 @@ class ProjectNodeSuite extends LocalNodeTest { AttributeReference("name", StringType)()) private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { - val inputNode = new DummyNode(pieAttributes, inputData) + val inputNode = DummyNode(pieAttributes, inputData) val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) val projectNode = new ProjectNode(conf, columns, inputNode) val expectedOutput = inputData.map { case (id, age, name) => (id, name) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala index a3e83bbd5145..cd0d6a4a91ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -29,7 +29,7 @@ class SampleNodeSuite extends LocalNodeTest { val maybeOut = if (withReplacement) "" else "out" test(s"with$maybeOut replacement") { val inputData = (1 to 1000).map { i => (i, i) }.toArray - val inputNode = new DummyNode(kvIntAttributes, inputData) + val inputNode = DummyNode(kvIntAttributes, inputData) val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) val sampler = if (withReplacement) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala index 42ebc7bfcaad..2b5bc6d39abb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -30,7 +30,7 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { val ascOrDesc = if (desc) "desc" else "asc" test(ascOrDesc) { val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray - val inputNode = new DummyNode(kvIntAttributes, inputData) + val inputNode = DummyNode(kvIntAttributes, inputData) val firstColumn = inputNode.output(0) val sortDirection = if (desc) Descending else Ascending val sortOrder = SortOrder(firstColumn, sortDirection) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index 666b0235c061..46d72c3feb0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -22,7 +22,7 @@ class UnionNodeSuite extends LocalNodeTest { private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { val inputNodes = inputData.map { data => - new DummyNode(kvIntAttributes, data) + DummyNode(kvIntAttributes, data) } val unionNode = new UnionNode(conf, inputNodes) val expectedOutput = inputData.flatten From a0a991a7a0b0a8a17e3373ffc90edd74e64514e8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Dec 2015 10:53:21 +0800 Subject: [PATCH 2/6] Add copy(). --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ed326b635ff7..74516dd16c89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -938,7 +938,8 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Project(projectList, LocalRelation(output, data)) => val projection = UnsafeProjection.create(projectList, output) - LocalRelation(projectList.map(_.toAttribute), data.map(projection(_).asInstanceOf[UnsafeRow])) + LocalRelation(projectList.map(_.toAttribute), + data.map(projection(_).copy().asInstanceOf[UnsafeRow])) } } From f0e6ac0f412795f0d7ef5285eea0fc06a63dc270 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Dec 2015 19:19:31 +0800 Subject: [PATCH 3/6] Fix several bugs. 
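Converting LocalRelation data to UnsafeRow eagerly means the encoders and
object serializers now run at relation-construction time, which exposes a few
latent null-handling problems. This commit makes the reflection-based
serializer, MapObjects and GenericArrayData tolerate null inputs, switches
ExpandNode to UnsafeProjection, and has FrequentItems return its item column
as a Seq instead of an Array. A hypothetical illustration of the kind of
input that now hits the conversion path up front (illustrative only, not part
of the diff below):

    // Illustrative only: a null Seq column reaches GenericArrayData when the
    // relation is built, instead of being deferred to execution time.
    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.StructType

    case class Rec(id: Int, tags: Seq[String])
    implicit val enc: ExpressionEncoder[Rec] = ExpressionEncoder()
    val attrs = ScalaReflection.schemaFor[Rec].dataType
      .asInstanceOf[StructType].toAttributes
    val rel = LocalRelation.fromProduct(attrs, Seq(Rec(1, null), Rec(2, Seq("a"))))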
--- .../spark/sql/catalyst/ScalaReflection.scala | 118 +++++++++++++----- .../sql/catalyst/expressions/objects.scala | 6 +- .../sql/catalyst/util/GenericArrayData.scala | 24 +++- .../sql/execution/local/ExpandNode.scala | 8 +- .../sql/execution/stat/FrequentItems.scala | 2 +- 5 files changed, 122 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9013fd050b5f..7b0cb8094c97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -61,6 +61,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.ByteTpe => ByteType case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT case _ => val className = getClassNameFromType(tpe) className match { @@ -177,6 +178,7 @@ object ScalaReflection extends ScalaReflection { case _ => UpCast(expr, expected, walkedTypePath) } + val className = getClassNameFromType(tpe) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -372,6 +374,17 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } + + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) } } @@ -406,11 +419,16 @@ object ScalaReflection extends ScalaReflection { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { val externalDataType = dataTypeFor(elementType) val Schema(catalystType, nullable) = silentSchemaFor(elementType) - if (isNativeType(catalystType)) { - NewInstance( + + if (isNativeType(catalystType) && !(elementType <:< localTypeOf[Option[_]])) { + val array = NewInstance( classOf[GenericArrayData], input :: Nil, dataType = ArrayType(catalystType, nullable)) + expressions.If( + IsNull(input), + expressions.Literal.create(null, ArrayType(catalystType, nullable)), + array) } else { val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath @@ -421,46 +439,75 @@ object ScalaReflection extends ScalaReflection { if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { + val className = getClassNameFromType(tpe) tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t optType match { // For primitive types we must manually unbox the value of the object. 
case t if t <:< definitions.IntTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), - "intValue", - IntegerType) + val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + Invoke( + unwrapped, + "intValue", + IntegerType)) case t if t <:< definitions.LongTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), - "longValue", - LongType) + val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + Invoke( + unwrapped, + "longValue", + LongType)) case t if t <:< definitions.DoubleTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), - "doubleValue", - DoubleType) + val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + Invoke( + unwrapped, + "doubleValue", + DoubleType)) case t if t <:< definitions.FloatTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), - "floatValue", - FloatType) + val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + Invoke( + unwrapped, + "floatValue", + FloatType)) case t if t <:< definitions.ShortTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), - "shortValue", - ShortType) + val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + Invoke( + unwrapped, + "shortValue", + ShortType)) case t if t <:< definitions.ByteTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), - "byteValue", - ByteType) + val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + Invoke( + unwrapped, + "byteValue", + ByteType)) case t if t <:< definitions.BooleanTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), - "booleanValue", - BooleanType) + val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + Invoke( + unwrapped, + "booleanValue", + BooleanType)) // For non-primitives, we can just extract the object from the Option and then recurse. 
case other => @@ -589,6 +636,17 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + case other => throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index b2facfda2444..e6df12a1652f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -463,7 +463,11 @@ case class MapObjects( $convertedArray[$loopIndex] = null; } else { ${genFunction.code} - $convertedArray[$loopIndex] = ${genFunction.value}; + if (${genFunction.isNull}) { + $convertedArray[$loopIndex] = null; + } else { + $convertedArray[$loopIndex] = ${genFunction.value}; + } } $loopIndex += 1; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 2b8cdc1e23ab..84496a57a2aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -18,14 +18,30 @@ package org.apache.spark.sql.catalyst.util import scala.collection.JavaConverters._ +import scala.collection.mutable.WrappedArray import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, Decimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +object GenericArrayData { + def processSeq(seq: Seq[Any]): Array[Any] = { + seq match { + case wArray: WrappedArray[_] => + if (wArray.array == null) { + null + } else { + wArray.toArray[Any] + } + case null => null + case _ => seq.toArray + } + } +} + class GenericArrayData(val array: Array[Any]) extends ArrayData { - def this(seq: Seq[Any]) = this(seq.toArray) + def this(seq: Seq[Any]) = this(GenericArrayData.processSeq(seq)) def this(list: java.util.List[Any]) = this(list.asScala) // TODO: This is boxing. We should specialize. 
@@ -39,7 +55,11 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { override def copy(): ArrayData = new GenericArrayData(array.clone()) - override def numElements(): Int = array.length + override def numElements(): Int = if (array != null) { + array.length + } else { + 0 + } private def getAs[T](ordinal: Int) = array(ordinal).asInstanceOf[T] override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala index 2aff156d18b5..c0d742c78f43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection, UnsafeProjection} case class ExpandNode( conf: SQLConf, @@ -29,6 +29,10 @@ case class ExpandNode( assert(projections.size > 0) + override def canProcessUnsafeRows: Boolean = true + + override def outputsUnsafeRows: Boolean = true + private[this] var result: InternalRow = _ private[this] var idx: Int = _ private[this] var input: InternalRow = _ @@ -36,7 +40,7 @@ case class ExpandNode( override def open(): Unit = { child.open() - groups = projections.map(ee => newProjection(ee, child.output)).toArray + groups = projections.map(ee => UnsafeProjection.create(ee, child.output)).toArray idx = groups.length } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index db463029aedf..a6e0e2682dbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -114,7 +114,7 @@ private[sql] object FrequentItems extends Logging { baseCounts } ) - val justItems = freqItems.map(m => m.baseMap.keys.toArray) + val justItems = freqItems.map(m => m.baseMap.keys.toSeq) val resultRow = Row(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => From 97a81c3ab8a456968ed22e8a4af8e0ad8a505f26 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Dec 2015 01:40:55 +0800 Subject: [PATCH 4/6] Fix remaining failed tests. --- .../spark/sql/catalyst/ScalaReflection.scala | 84 +++++++------------ .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../plans/logical/LocalRelation.scala | 5 +- .../org/apache/spark/sql/SQLContext.scala | 3 +- .../spark/sql/execution/basicOperators.scala | 4 + 5 files changed, 37 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bd5260890524..c2148d1b844b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -446,68 +446,40 @@ object ScalaReflection extends ScalaReflection { optType match { // For primitive types we must manually unbox the value of the object. 
case t if t <:< definitions.IntTpe => - val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - Invoke( - unwrapped, - "intValue", - IntegerType)) + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), + "intValue", + IntegerType) case t if t <:< definitions.LongTpe => - val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - Invoke( - unwrapped, - "longValue", - LongType)) + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), + "longValue", + LongType) case t if t <:< definitions.DoubleTpe => - val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - Invoke( - unwrapped, - "doubleValue", - DoubleType)) + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), + "doubleValue", + DoubleType) case t if t <:< definitions.FloatTpe => - val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - Invoke( - unwrapped, - "floatValue", - FloatType)) + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), + "floatValue", + FloatType) case t if t <:< definitions.ShortTpe => - val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - Invoke( - unwrapped, - "shortValue", - ShortType)) + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), + "shortValue", + ShortType) case t if t <:< definitions.ByteTpe => - val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - Invoke( - unwrapped, - "byteValue", - ByteType)) + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), + "byteValue", + ByteType) case t if t <:< definitions.BooleanTpe => - val unwrapped = UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - Invoke( - unwrapped, - "booleanValue", - BooleanType)) + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), + "booleanValue", + BooleanType) // For non-primitives, we can just extract the object from the Option and then recurse. 
case other => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 74516dd16c89..717f4b1f4873 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -939,7 +939,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Project(projectList, LocalRelation(output, data)) => val projection = UnsafeProjection.create(projectList, output) LocalRelation(projectList.map(_.toAttribute), - data.map(projection(_).copy().asInstanceOf[UnsafeRow])) + data.map(projection(_).copy())) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 5e56ce10083c..e5960ae908fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -37,8 +37,9 @@ object LocalRelation { def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { val schema = StructType.fromAttributes(output) - val encoder = RowEncoder(schema) - LocalRelation(output, data.map(encoder.toRow(_).copy().asInstanceOf[UnsafeRow])) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val internalRows = data.map(converter(_).asInstanceOf[InternalRow]) + fromInternalRows(output, internalRows) } def fromProduct[T <: Product : ExpressionEncoder]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 02f244fd9d22..28c2eddc1d55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -606,9 +606,8 @@ class SQLContext private[sql]( val className = beanClass.getName val beanInfo = Introspector.getBeanInfo(beanClass) val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) - val projection = UnsafeProjection.create(attrSeq) DataFrame(self, - LocalRelation(attrSeq, rows.toSeq.map(projection(_).copy().asInstanceOf[UnsafeRow]))) + LocalRelation.fromInternalRows(attrSeq, rows.toSeq)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index b3e4688557ba..e1c02fea2e84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -162,6 +162,10 @@ case class Limit(limit: Int, child: SparkPlan) override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = SinglePartition + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + override def executeCollect(): Array[InternalRow] = child.executeTake(limit) protected override def doExecute(): RDD[InternalRow] = { From 97d390f9d4a929f6e115529d8bf8ec096074543c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Dec 2015 11:11:06 +0800 Subject: [PATCH 5/6] Ignore the test for Limit node as it accepts UnsafeRow now. 
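Patch 4 makes Limit declare canProcessUnsafeRows and canProcessSafeRows, so
the planner no longer has to insert a ConvertToSafe above an unsafe child of
Limit and the assertion in this test no longer holds. Roughly, the behaviour
the ignored test was checking (names taken from RowFormatConvertersSuite; the
"after" comment describes the expectation once this series is applied):

    val plan = Limit(10, outputsUnsafe)
    val preparedPlan = sqlContext.prepareForExecution.execute(plan)
    // before: preparedPlan.children.head.isInstanceOf[ConvertToSafe]
    // after:  Limit consumes the UnsafeRows directly; no conversion is added.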
--- .../apache/spark/sql/execution/RowFormatConvertersSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 9c47edbae013..25828dbd89a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -38,7 +38,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) - test("planner should insert unsafe->safe conversions when required") { + ignore("planner should insert unsafe->safe conversions when required") { val plan = Limit(10, outputsUnsafe) val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) From 2500de3ba716ad93dca8001f5fde6c670c898416 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Dec 2015 10:50:12 +0800 Subject: [PATCH 6/6] For comment. --- .../spark/sql/catalyst/plans/logical/LocalRelation.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index e5960ae908fd..d4bd3084bd32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} @@ -45,7 +46,7 @@ object LocalRelation { def fromProduct[T <: Product : ExpressionEncoder]( output: Seq[Attribute], data: Seq[T]): LocalRelation = { - val encoder = implicitly[ExpressionEncoder[T]] + val encoder = encoderFor[T] val schema = StructType.fromAttributes(output) new LocalRelation(output, data.map(encoder.toRow(_).copy().asInstanceOf[UnsafeRow])) }
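
A closing sketch of why the projected rows are copied in ConvertToLocalRelation
(projectList, output and data are the names bound in the rule above; the
snippet is illustrative and not part of any diff in this series):
UnsafeProjection instances reuse a single output row buffer, so each projected
row has to be materialized before being stored in the new LocalRelation.

    val projection = UnsafeProjection.create(projectList, output)
    val aliased = data.map(projection(_))         // every element aliases one reused buffer
    val correct = data.map(projection(_).copy())  // materializes each row, as done above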