From c549d3e45c641a94300267a7693c082fe2fa6d7b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 6 Sep 2018 23:48:04 +0900 Subject: [PATCH 1/6] Implement InterpretedSafeProjection --- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../InterpretedSafeProjection.scala | 173 ++++++++++++++++++ .../sql/catalyst/expressions/Projection.scala | 37 ++-- .../expressions/CodeGenerationSuite.scala | 2 +- ...eneratorWithInterpretedFallbackSuite.scala | 15 ++ .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/UnsafeRowConverterSuite.scala | 98 +++++++++- .../DeclarativeAggregateEvaluator.scala | 11 +- .../codegen/GeneratedProjectionSuite.scala | 8 +- .../util/ArrayDataIndexedSeqSuite.scala | 4 +- .../org/apache/spark/sql/types/TestUDT.scala | 61 ++++++ .../spark/sql/FileBasedDataSourceSuite.scala | 4 +- .../spark/sql/UserDefinedTypeSuite.scala | 105 ++++------- .../datasources/json/JsonSuite.scala | 4 +- .../datasources/orc/OrcQuerySuite.scala | 4 +- .../execution/AggregationQuerySuite.scala | 2 +- .../execution/ObjectHashAggregateSuite.scala | 4 +- .../sql/sources/HadoopFsRelationTest.scala | 2 +- 18 files changed, 427 insertions(+), 113 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 589e215c55e4..fbf0bd68b958 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -302,7 +302,7 @@ case class ExpressionEncoder[T]( private lazy val inputRow = new GenericInternalRow(1) @transient - private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) + private lazy val constructProjection = SafeProjection.create(deserializer :: Nil) /** * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala new file mode 100644 index 000000000000..884f7e0c8360 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.types._
+
+
+/**
+ * An interpreted version of a safe projection.
+ *
+ * @param expressions the expressions that produce the resulting fields. These expressions
+ *                    must be bound to a schema.
+ */
+class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection {
+
+  private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType))
+
+  private[this] val exprsWithWriters = expressions.zipWithIndex.filter {
+    case (NoOp, _) => false
+    case _ => true
+  }.map { case (e, i) =>
+    val converter = generateSafeValueConverter(e.dataType)
+    val writer = generateRowWriter(i, e.dataType)
+    val f = if (!e.nullable) {
+      (v: Any) => writer(converter(v))
+    } else {
+      (v: Any) => {
+        if (v == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          writer(converter(v))
+        }
+      }
+    }
+    (e, f)
+  }
+
+  private def isPrimitive(dataType: DataType): Boolean = dataType match {
+    case BooleanType => true
+    case ByteType => true
+    case ShortType => true
+    case IntegerType => true
+    case LongType => true
+    case FloatType => true
+    case DoubleType => true
+    case _ => false
+  }
+
+  private def generateSafeValueConverter(dt: DataType): Any => Any = dt match {
+    case ArrayType(elemType, _) =>
+      if (isPrimitive(elemType)) {
+        v => {
+          val arrayValue = v.asInstanceOf[ArrayData]
+          new GenericArrayData(arrayValue.toArray[Any](elemType))
+        }
+      } else {
+        val elementConverter = generateSafeValueConverter(elemType)
+        v => {
+          val arrayValue = v.asInstanceOf[ArrayData]
+          val result = new Array[Any](arrayValue.numElements())
+          arrayValue.foreach(elemType, (i, e) => {
+            result(i) = elementConverter(e)
+          })
+          new GenericArrayData(result)
+        }
+      }
+
+    case st: StructType =>
+      val fieldTypes = st.fields.map(_.dataType)
+      val fieldConverters = fieldTypes.map(generateSafeValueConverter)
+      v => {
+        val row = v.asInstanceOf[InternalRow]
+        val ar = new Array[Any](row.numFields)
+        var idx = 0
+        while (idx < row.numFields) {
+          ar(idx) = fieldConverters(idx)(row.get(idx, fieldTypes(idx)))
+          idx += 1
+        }
+        new GenericInternalRow(ar)
+      }
+
+    case MapType(keyType, valueType, _) =>
+      lazy val keyConverter = generateSafeValueConverter(keyType)
+      lazy val valueConverter = generateSafeValueConverter(valueType)
+      v => {
+        val mapValue = v.asInstanceOf[MapData]
+        val keys = mapValue.keyArray().toArray[Any](keyType)
+        val values = mapValue.valueArray().toArray[Any](valueType)
+        val convertedKeys =
+          if (isPrimitive(keyType)) keys else keys.map(keyConverter)
+        val convertedValues =
+          if (isPrimitive(valueType)) values else values.map(valueConverter)
+
+        ArrayBasedMapData(convertedKeys, convertedValues)
+      }
+
+    case udt: UserDefinedType[_] =>
+      generateSafeValueConverter(udt.sqlType)
+
+    case _ => identity
+  }
+
+  private def generateRowWriter(ordinal: Int, dt: DataType): Any => Unit = dt match {
+    case BooleanType =>
+      v => mutableRow.setBoolean(ordinal, v.asInstanceOf[Boolean])
+    case ByteType =>
+      v => mutableRow.setByte(ordinal, v.asInstanceOf[Byte])
+    case ShortType =>
+      v => mutableRow.setShort(ordinal, v.asInstanceOf[Short])
+    case IntegerType | DateType =>
+      v => mutableRow.setInt(ordinal, v.asInstanceOf[Int])
+    case LongType | TimestampType =>
+      v => mutableRow.setLong(ordinal, v.asInstanceOf[Long])
+    case FloatType =>
+      v => mutableRow.setFloat(ordinal, v.asInstanceOf[Float])
+    case DoubleType =>
+      v => mutableRow.setDouble(ordinal, v.asInstanceOf[Double])
+    case DecimalType.Fixed(precision, _) =>
+      v => mutableRow.setDecimal(ordinal, v.asInstanceOf[Decimal], precision)
+    case CalendarIntervalType | BinaryType | _: ArrayType | StringType | _: StructType |
+         _: MapType | _: UserDefinedType[_] =>
+      v => mutableRow.update(ordinal, v)
+    case NullType =>
+      v => {}
+    case _ =>
+      throw new SparkException(s"Unsupported data type $dt")
+  }
+
+  override def apply(row: InternalRow): InternalRow = {
+    var i = 0
+    while (i < exprsWithWriters.length) {
+      val (expr, writer) = exprsWithWriters(i)
+      writer(expr.eval(row))
+      i += 1
+    }
+    mutableRow
+  }
+}
+
+/**
+ * Helper functions for creating an [[InterpretedSafeProjection]].
+ */
+object InterpretedSafeProjection {
+
+  /**
+   * Returns a [[SafeProjection]] for given sequence of bound Expressions.
+   */
+  def createProjection(exprs: Seq[Expression]): Projection = {
+    // We need to make sure that we do not reuse stateful expressions.
+    val cleanedExpressions = exprs.map(_.transform {
+      case s: Stateful => s.freshCopy()
+    })
+    new InterpretedSafeProjection(cleanedExpressions)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 792646cf9f10..65a27a465e40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -166,29 +166,40 @@ object UnsafeProjection
   }
 }
 
-/**
- * A projection that could turn UnsafeRow into GenericInternalRow
- */
-object FromUnsafeProjection {
+object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expression], Projection] {
+
+  override protected def createCodeGeneratedObject(in: Seq[Expression]): Projection = {
+    GenerateSafeProjection.generate(in)
+  }
+
+  override protected def createInterpretedObject(in: Seq[Expression]): Projection = {
+    InterpretedSafeProjection.createProjection(in)
+  }
 
   /**
-   * Returns a Projection for given StructType.
+   * Returns a SafeProjection for given StructType.
    */
-  def apply(schema: StructType): Projection = {
-    apply(schema.fields.map(_.dataType))
+  def create(schema: StructType): Projection = create(schema.fields.map(_.dataType))
+
+  /**
+   * Returns a SafeProjection for given Array of DataTypes.
+   */
+  def create(fields: Array[DataType]): Projection = {
+    createObject(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
   }
 
   /**
-   * Returns an UnsafeProjection for given Array of DataTypes.
+   * Returns a SafeProjection for given sequence of Expressions (bound).
    */
-  def apply(fields: Seq[DataType]): Projection = {
-    create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+  def create(exprs: Seq[Expression]): Projection = {
+    createObject(exprs)
   }
 
   /**
-   * Returns a Projection for given sequence of Expressions (bounded).
+   * Returns a SafeProjection for given sequence of Expressions, which will be bound to
+   * `inputSchema`.
*/ - private def create(exprs: Seq[Expression]): Projection = { - GenerateSafeProjection.generate(exprs) + def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { + create(toBoundExprs(exprs, inputSchema)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7843003a4aac..7e6fe5b4e206 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -251,7 +251,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { UTF8String.fromString("c")) assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3) - val fromUnsafe = FromUnsafeProjection(schema) + val fromUnsafe = SafeProjection.create(schema) val internalRow2 = fromUnsafe(unsafeRow) assert(internalRow === internalRow2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala index 6ea3b05ff9c1..da5bddb0c09f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -106,4 +106,19 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected) } } + + test("SPARK-25374 Correctly handles NoOp in SafeProjection") { + val exprs = Seq(Add(BoundReference(0, IntegerType, nullable = true), Literal.create(1)), NoOp) + val input = InternalRow.fromSeq(1 :: 1 :: Nil) + val expected = 2 :: null :: Nil + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val proj = SafeProjection.createObject(exprs) + assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected) + } + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val proj = SafeProjection.createObject(exprs) + assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a7282e1b1cad..b4fd170467d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -321,8 +321,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) plan.initialize(0) - actual = FromUnsafeProjection(expression.dataType :: Nil)( - plan(inputRow)).get(0, expression.dataType) + val ref = new BoundReference(0, expression.dataType, nullable = true) + actual = GenerateSafeProjection.generate(ref :: Nil)(plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected, expression)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 268372b5d050..5f522eebbcf9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase with ExpressionEvalHelper { @@ -535,4 +535,100 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes)) } + + testBothCodegenAndInterpreted("SPARK-25374 converts back into safe representation") { + def convertBackToInternalRow(inputRow: InternalRow, fields: Array[DataType]): InternalRow = { + val unsafeProj = UnsafeProjection.create(fields) + val unsafeRow = unsafeProj(inputRow) + val safeProj = SafeProjection.create(fields) + safeProj(unsafeRow) + } + + // Simple tests + val inputRow = InternalRow.fromSeq(Seq( + false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0, UTF8String.fromString("test"), + Decimal(255), CalendarInterval.fromString("interval 1 day"), Array[Byte](1, 2) + )) + val fields1 = Array( + BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, StringType, DecimalType.defaultConcreteType, CalendarIntervalType, + BinaryType) + + assert(convertBackToInternalRow(inputRow, fields1) === inputRow) + + // Array tests + val arrayRow = InternalRow.fromSeq(Seq( + createArray(1, 2, 3), + createArray( + createArray(Seq("a", "b", "c").map(UTF8String.fromString): _*), + createArray(Seq("d").map(UTF8String.fromString): _*)) + )) + val fields2 = Array[DataType]( + ArrayType(IntegerType), + ArrayType(ArrayType(StringType))) + + assert(convertBackToInternalRow(arrayRow, fields2) === arrayRow) + + // Struct tests + val structRow = InternalRow.fromSeq(Seq( + InternalRow.fromSeq(Seq[Any](1, 4.0)), + InternalRow.fromSeq(Seq( + UTF8String.fromString("test"), + InternalRow.fromSeq(Seq( + 1, + createArray(Seq("2", "3").map(UTF8String.fromString): _*) + )) + )) + )) + val fields3 = Array[DataType]( + StructType( + StructField("c0", IntegerType) :: + StructField("c1", DoubleType) :: + Nil), + StructType( + StructField("c2", StringType) :: + StructField("c3", StructType( + StructField("c4", IntegerType) :: + StructField("c5", ArrayType(StringType)) :: + Nil)) :: + Nil)) + + assert(convertBackToInternalRow(structRow, fields3) === structRow) + + // Map tests + val mapRow = InternalRow.fromSeq(Seq( + createMap(Seq("k1", "k2").map(UTF8String.fromString): _*)(1, 2), + createMap( + createMap(3, 5)(Seq("v1", "v2").map(UTF8String.fromString): _*), + createMap(7, 9)(Seq("v3", "v4").map(UTF8String.fromString): _*) + )( + createMap(Seq("k3", "k4").map(UTF8String.fromString): _*)(3.toShort, 4.toShort), + createMap(Seq("k5", "k6").map(UTF8String.fromString): _*)(5.toShort, 6.toShort) + ))) + val fields4 = Array[DataType]( + MapType(StringType, IntegerType), + MapType(MapType(IntegerType, StringType), MapType(StringType, ShortType))) + + // Since `ArrayBasedMapData` does not 
override `equals` and `hashCode`, + // we need to take care of it to compare rows. + def toComparable(d: Any): Any = d match { + case ar: GenericArrayData => + ar.array.map(toComparable).toSeq + case map: ArrayBasedMapData => + val keys = map.keyArray.array.map(toComparable).toSeq + val values = map.valueArray.array.map(toComparable).toSeq + (keys, values) + case o => o + } + val mapResultRow = convertBackToInternalRow(mapRow, fields4).toSeq(fields4) + val mapExpectedRow = mapRow.toSeq(fields4) + assert(mapResultRow.map(toComparable) === mapExpectedRow.map(toComparable)) + + // UDT tests + val vector = new TestUDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) + val udt = new TestUDT.MyDenseVectorUDT() + val udtRow = InternalRow.fromSeq(Seq(udt.serialize(vector))) + val fields5 = Array[DataType](udt) + assert(convertBackToInternalRow(udtRow, fields5) === udtRow) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala index 614f24db0aaf..b0f55b3b5c44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala @@ -17,25 +17,24 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, SafeProjection} /** * Evaluator for a [[DeclarativeAggregate]]. 
*/ case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) { - lazy val initializer = GenerateSafeProjection.generate(function.initialValues) + lazy val initializer = SafeProjection.create(function.initialValues) - lazy val updater = GenerateSafeProjection.generate( + lazy val updater = SafeProjection.create( function.updateExpressions, function.aggBufferAttributes ++ input) - lazy val merger = GenerateSafeProjection.generate( + lazy val merger = SafeProjection.create( function.mergeExpressions, function.aggBufferAttributes ++ function.inputAggBufferAttributes) - lazy val evaluator = GenerateSafeProjection.generate( + lazy val evaluator = SafeProjection.create( function.evaluateExpression :: Nil, function.aggBufferAttributes) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 2c45b3b0c73d..4c9bcfe8f93a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -58,7 +58,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { } // test generated SafeProjection - val safeProj = FromUnsafeProjection(nestedSchema) + val safeProj = SafeProjection.create(nestedSchema) val result = safeProj(unsafe) // Can't compare GenericInternalRow with JoinedRow directly (0 until N).foreach { i => @@ -109,7 +109,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { } // test generated SafeProjection - val safeProj = FromUnsafeProjection(nestedSchema) + val safeProj = SafeProjection.create(nestedSchema) val result = safeProj(unsafe) // Can't compare GenericInternalRow with JoinedRow directly (0 until N).foreach { i => @@ -147,7 +147,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { assert(unsafeRow.getArray(1).getBinary(1) === null) assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4))) - val safeProj = FromUnsafeProjection(fields) + val safeProj = SafeProjection.create(fields) val row2 = safeProj(unsafeRow) assert(row2 === row) } @@ -233,7 +233,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { val nestedSchema = StructType( Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) - val safeProj = FromUnsafeProjection(nestedSchema) + val safeProj = SafeProjection.create(nestedSchema) val result = safeProj(nested) // test generated MutableProjection diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala index 6400898343ae..da71e3a4d53e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{SafeProjection, UnsafeArrayData, UnsafeProjection} import 
org.apache.spark.sql.types._ class ArrayDataIndexedSeqSuite extends SparkFunSuite { @@ -77,7 +77,7 @@ class ArrayDataIndexedSeqSuite extends SparkFunSuite { val internalRow = rowConverter.toRow(row) val unsafeRowConverter = UnsafeProjection.create(schema) - val safeRowConverter = FromUnsafeProjection(schema) + val safeRowConverter = SafeProjection.create(schema) val unsafeRow = unsafeRowConverter(internalRow) val safeRow = safeRowConverter(unsafeRow) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala new file mode 100644 index 000000000000..1be8ee9dfa92 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} + + +// Wrapped in an object to check Scala compatibility. See SPARK-13929 +object TestUDT { + + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) + private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { + override def hashCode(): Int = java.util.Arrays.hashCode(data) + + override def equals(other: Any): Boolean = other match { + case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) + case _ => false + } + + override def toString: String = data.mkString("(", ", ", ")") + } + + private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { + + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) + + override def serialize(features: MyDenseVector): ArrayData = { + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) + } + + override def deserialize(datum: Any): MyDenseVector = { + datum match { + case data: ArrayData => + new MyDenseVector(data.toDoubleArray()) + } + } + + override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] + + private[spark] override def asNullable: MyDenseVectorUDT = this + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT] + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 64b42c32b8b1..54299e9808bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -312,13 +312,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo assert(msg.contains("CSV data source does not support array data type")) msg = intercept[AnalysisException] { - Seq((1, new 
UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") + Seq((1, new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") .write.mode("overwrite").csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support array data type")) msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) + val schema = StructType(StructField("a", new TestUDT.MyDenseVectorUDT(), true) :: Nil) spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index cf956316057e..6628d36ffc70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -20,56 +20,14 @@ package org.apache.spark.sql import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -private[sql] case class MyLabeledPoint(label: Double, features: UDT.MyDenseVector) { +private[sql] case class MyLabeledPoint(label: Double, features: TestUDT.MyDenseVector) { def getLabel: Double = label - def getFeatures: UDT.MyDenseVector = features -} - -// Wrapped in an object to check Scala compatibility. 
See SPARK-13929 -object UDT { - - @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) - private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { - override def hashCode(): Int = java.util.Arrays.hashCode(data) - - override def equals(other: Any): Boolean = other match { - case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) - case _ => false - } - - override def toString: String = data.mkString("(", ", ", ")") - } - - private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { - - override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - - override def serialize(features: MyDenseVector): ArrayData = { - new GenericArrayData(features.data.map(_.asInstanceOf[Any])) - } - - override def deserialize(datum: Any): MyDenseVector = { - datum match { - case data: ArrayData => - new MyDenseVector(data.toDoubleArray()) - } - } - - override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] - - private[spark] override def asNullable: MyDenseVectorUDT = this - - override def hashCode(): Int = getClass.hashCode() - - override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT] - } - + def getFeatures: TestUDT.MyDenseVector = features } // object and classes to test SPARK-19311 @@ -148,12 +106,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT import testImplicits._ private lazy val pointsRDD = Seq( - MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))).toDF() + MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0)))).toDF() private lazy val pointsRDD2 = Seq( - MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.3, 3.0)))).toDF() + MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.3, 3.0)))).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -162,16 +120,17 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) - val features: RDD[UDT.MyDenseVector] = - pointsRDD.select('features).rdd.map { case Row(v: UDT.MyDenseVector) => v } - val featuresArrays: Array[UDT.MyDenseVector] = features.collect() + val features: RDD[TestUDT.MyDenseVector] = + pointsRDD.select('features).rdd.map { case Row(v: TestUDT.MyDenseVector) => v } + val featuresArrays: Array[TestUDT.MyDenseVector] = features.collect() assert(featuresArrays.size === 2) - assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.1, 1.0)))) - assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.2, 2.0)))) + assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.2, 2.0)))) } test("UDTs and UDFs") { - spark.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector]) + spark.udf.register("testType", + (d: TestUDT.MyDenseVector) => d.isInstanceOf[TestUDT.MyDenseVector]) pointsRDD.createOrReplaceTempView("points") checkAnswer( sql("SELECT testType(features) from points"), @@ -185,8 +144,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT checkAnswer( spark.read.parquet(path), Seq( - 
Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), - Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) + Row(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0))))) } } @@ -197,17 +156,17 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT checkAnswer( spark.read.parquet(path), Seq( - Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), - Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) + Row(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0))))) } } // Tests to make sure that all operators correctly convert types on the way out. test("Local UDTs") { - val vec = new UDT.MyDenseVector(Array(0.1, 1.0)) + val vec = new TestUDT.MyDenseVector(Array(0.1, 1.0)) val df = Seq((1, vec)).toDF("int", "vec") - assert(vec === df.collect()(0).getAs[UDT.MyDenseVector](1)) - assert(vec === df.take(1)(0).getAs[UDT.MyDenseVector](1)) + assert(vec === df.collect()(0).getAs[TestUDT.MyDenseVector](1)) + assert(vec === df.take(1)(0).getAs[TestUDT.MyDenseVector](1)) checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) } @@ -219,14 +178,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT ) val schema = StructType(Seq( StructField("id", IntegerType, false), - StructField("vec", new UDT.MyDenseVectorUDT, false) + StructField("vec", new TestUDT.MyDenseVectorUDT, false) )) val jsonRDD = spark.read.schema(schema).json(data.toDS()) checkAnswer( jsonRDD, - Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: - Row(2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75))) :: + Row(1, new TestUDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: + Row(2, new TestUDT.MyDenseVector(Array(2.25, 4.5, 8.75))) :: Nil ) } @@ -239,25 +198,25 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val schema = StructType(Seq( StructField("id", IntegerType, false), - StructField("vec", new UDT.MyDenseVectorUDT, false) + StructField("vec", new TestUDT.MyDenseVectorUDT, false) )) val jsonDataset = spark.read.schema(schema).json(data.toDS()) - .as[(Int, UDT.MyDenseVector)] + .as[(Int, TestUDT.MyDenseVector)] checkDataset( jsonDataset, - (1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))), - (2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75))) + (1, new TestUDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))), + (2, new TestUDT.MyDenseVector(Array(2.25, 4.5, 8.75))) ) } test("SPARK-10472 UserDefinedType.typeName") { assert(IntegerType.typeName === "integer") - assert(new UDT.MyDenseVectorUDT().typeName === "mydensevector") + assert(new TestUDT.MyDenseVectorUDT().typeName === "mydensevector") } test("Catalyst type converter null handling for UDTs") { - val udt = new UDT.MyDenseVectorUDT() + val udt = new TestUDT.MyDenseVectorUDT() val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) assert(toScalaConverter(null) === null) @@ -303,12 +262,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT test("except on UDT") { checkAnswer( pointsRDD.except(pointsRDD2), - Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) + Seq(Row(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0))))) } test("SPARK-23054 Cast UserDefinedType to string") { - val udt = new UDT.MyDenseVectorUDT() - val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) + val udt = new TestUDT.MyDenseVectorUDT() + val vector = new 
TestUDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) val data = udt.serialize(vector) val ret = Cast(Literal(data, udt), StringType, None) checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 9d23161c1f24..dff37ca2d40f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1463,7 +1463,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct, - new UDT.MyDenseVectorUDT()) + new TestUDT.MyDenseVectorUDT()) val fields = dataTypes.zipWithIndex.map { case (dataType, index) => StructField(s"col$index", dataType, nullable = true) } @@ -1487,7 +1487,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Seq(2, 3, 4), Map("a string" -> 2000L), Row(4.75.toFloat, Seq(false, true)), - new UDT.MyDenseVector(Array(0.25, 2.25, 4.25))) + new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25))) val data = Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 998b7b31dcd6..918dbcdfa1cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructType, TestUDT} import org.apache.spark.util.Utils case class AllDataTypesWithNonPrimitiveType( @@ -103,7 +103,7 @@ abstract class OrcQueryTest extends OrcTest { test("Read/write UserDefinedType") { withTempPath { path => - val data = Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))) + val data = Seq((1, new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25)))) val udtDF = data.toDF("id", "vectors") udtDF.write.orc(path.getAbsolutePath) val readBack = spark.read.schema(udtDF.schema).orc(path.getAbsolutePath) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index c65bf7c14c7a..cfae2d82e273 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -884,7 +884,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct, - new UDT.MyDenseVectorUDT()) + new TestUDT.MyDenseVectorUDT()) // Right now, we will use SortAggregate to handle UDAFs. 
// UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortAggregate to use // UnsafeRow as the aggregation buffer. While, dataTypes will trigger diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index c9309197791b..2391106cfb25 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -124,7 +124,7 @@ class ObjectHashAggregateSuite .add("f2", ArrayType(BooleanType), nullable = true), // UDT - new UDT.MyDenseVectorUDT(), + new TestUDT.MyDenseVectorUDT(), // Others StringType, @@ -259,7 +259,7 @@ class ObjectHashAggregateSuite StringType, BinaryType, NullType, BooleanType ) - val udt = new UDT.MyDenseVectorUDT() + val udt = new TestUDT.MyDenseVectorUDT() val fixedLengthTypes = builtinNumericTypes ++ Seq(BooleanType, NullType) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 6bd59fde550d..6075f2c8877d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -115,7 +115,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes new StructType() .add("f1", FloatType, nullable = true) .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), - new UDT.MyDenseVectorUDT() + new TestUDT.MyDenseVectorUDT() ).filter(supportsDataType) test(s"test all data types") { From 127fd2718667f53369c22107a2df8c8504a8d76e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 23 Oct 2018 11:12:47 +0900 Subject: [PATCH 2/6] Fix --- .../InterpretedSafeProjection.scala | 41 +++++-------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala index 884f7e0c8360..b3d9007d9068 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala @@ -53,34 +53,16 @@ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection (e, f) } - private def isPrimitive(dataType: DataType): Boolean = dataType match { - case BooleanType => true - case ByteType => true - case ShortType => true - case IntegerType => true - case LongType => true - case FloatType => true - case DoubleType => true - case _ => false - } - private def generateSafeValueConverter(dt: DataType): Any => Any = dt match { case ArrayType(elemType, _) => - if (isPrimitive(elemType)) { - v => { - val arrayValue = v.asInstanceOf[ArrayData] - new GenericArrayData(arrayValue.toArray[Any](elemType)) - } - } else { - val elementConverter = generateSafeValueConverter(elemType) - v => { - val arrayValue = v.asInstanceOf[ArrayData] - val result = new Array[Any](arrayValue.numElements()) - arrayValue.foreach(elemType, (i, e) => { - result(i) = elementConverter(e) - }) - new GenericArrayData(result) - } + val elementConverter = generateSafeValueConverter(elemType) + v => { + val arrayValue = 
v.asInstanceOf[ArrayData] + val result = new Array[Any](arrayValue.numElements()) + arrayValue.foreach(elemType, (i, e) => { + result(i) = elementConverter(e) + }) + new GenericArrayData(result) } case st: StructType => @@ -104,11 +86,8 @@ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection val mapValue = v.asInstanceOf[MapData] val keys = mapValue.keyArray().toArray[Any](keyType) val values = mapValue.valueArray().toArray[Any](valueType) - val convertedKeys = - if (isPrimitive(keyType)) keys else keys.map(keyConverter) - val convertedValues = - if (isPrimitive(valueType)) values else values.map(valueConverter) - + val convertedKeys = keys.map(keyConverter) + val convertedValues = values.map(valueConverter) ArrayBasedMapData(convertedKeys, convertedValues) } From cec84806c49ff7e616fc11e8703f0880807f41d9 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 24 Oct 2018 14:18:00 +0900 Subject: [PATCH 3/6] Fix --- .../InterpretedSafeProjection.scala | 33 ++----------------- 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala index b3d9007d9068..70789dac1d87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} @@ -38,15 +37,15 @@ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection case _ => true }.map { case (e, i) => val converter = generateSafeValueConverter(e.dataType) - val writer = generateRowWriter(i, e.dataType) + val writer = InternalRow.getWriter(i, e.dataType) val f = if (!e.nullable) { - (v: Any) => writer(converter(v)) + (v: Any) => writer(mutableRow, converter(v)) } else { (v: Any) => { if (v == null) { mutableRow.setNullAt(i) } else { - writer(converter(v)) + writer(mutableRow, converter(v)) } } } @@ -97,32 +96,6 @@ class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection case _ => identity } - private def generateRowWriter(ordinal: Int, dt: DataType): Any => Unit = dt match { - case BooleanType => - v => mutableRow.setBoolean(ordinal, v.asInstanceOf[Boolean]) - case ByteType => - v => mutableRow.setByte(ordinal, v.asInstanceOf[Byte]) - case ShortType => - v => mutableRow.setShort(ordinal, v.asInstanceOf[Short]) - case IntegerType | DateType => - v => mutableRow.setInt(ordinal, v.asInstanceOf[Int]) - case LongType | TimestampType => - v => mutableRow.setLong(ordinal, v.asInstanceOf[Long]) - case FloatType => - v => mutableRow.setFloat(ordinal, v.asInstanceOf[Float]) - case DoubleType => - v => mutableRow.setDouble(ordinal, v.asInstanceOf[Double]) - case DecimalType.Fixed(precision, _) => - v => mutableRow.setDecimal(ordinal, v.asInstanceOf[Decimal], precision) - case CalendarIntervalType | BinaryType | _: ArrayType | StringType | _: StructType | - _: MapType | _: UserDefinedType[_] => - v => mutableRow.update(ordinal, v) - case NullType => - v => {} - case _ => - throw new 
SparkException(s"Unsupported data type $dt") - } - override def apply(row: InternalRow): InternalRow = { var i = 0 while (i < exprsWithWriters.length) { From 0b23adb78f6b0553829e11f113a897856d5afa41 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Dec 2018 09:53:24 +0900 Subject: [PATCH 4/6] Fix --- .../spark/sql/catalyst/expressions/MutableProjectionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 2db1c3b98819..0d594eb10962 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -51,7 +51,7 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, fixedLengthTypes.length) val proj = createMutableProjection(fixedLengthTypes) val projUnsafeRow = proj.target(unsafeBuffer)(inputRow) - assert(FromUnsafeProjection.apply(fixedLengthTypes)(projUnsafeRow) === inputRow) + assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow) } testBothCodegenAndInterpreted("variable-length types") { From 7ef5f866eb02f6638a5be00a602de6c6810ae2a3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Dec 2018 13:55:13 +0900 Subject: [PATCH 5/6] Fix --- .../spark/sql/catalyst/expressions/Projection.scala | 3 +++ .../catalyst/expressions/UnsafeRowConverterSuite.scala | 8 +++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 65a27a465e40..b48f7ba655b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -166,6 +166,9 @@ object UnsafeProjection } } +/** + * A projection that could turn UnsafeRow into GenericInternalRow + */ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expression], Projection] { override protected def createCodeGeneratedObject(in: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 5f522eebbcf9..757216d99ae6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -609,19 +609,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB MapType(StringType, IntegerType), MapType(MapType(IntegerType, StringType), MapType(StringType, ShortType))) + val mapResultRow = convertBackToInternalRow(mapRow, fields4).toSeq(fields4) + val mapExpectedRow = mapRow.toSeq(fields4) // Since `ArrayBasedMapData` does not override `equals` and `hashCode`, - // we need to take care of it to compare rows. + // we convert it into the two `Seq`s of keys and values for correct comparisons. 
def toComparable(d: Any): Any = d match { - case ar: GenericArrayData => - ar.array.map(toComparable).toSeq case map: ArrayBasedMapData => val keys = map.keyArray.array.map(toComparable).toSeq val values = map.valueArray.array.map(toComparable).toSeq (keys, values) case o => o } - val mapResultRow = convertBackToInternalRow(mapRow, fields4).toSeq(fields4) - val mapExpectedRow = mapRow.toSeq(fields4) assert(mapResultRow.map(toComparable) === mapExpectedRow.map(toComparable)) // UDT tests From fbfbbff55d900ae1101ceb4f7823a9298464cb07 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Dec 2018 15:37:22 +0900 Subject: [PATCH 6/6] Fix --- .../expressions/UnsafeRowConverterSuite.scala | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 757216d99ae6..ecb8047459b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -609,18 +609,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB MapType(StringType, IntegerType), MapType(MapType(IntegerType, StringType), MapType(StringType, ShortType))) - val mapResultRow = convertBackToInternalRow(mapRow, fields4).toSeq(fields4) - val mapExpectedRow = mapRow.toSeq(fields4) - // Since `ArrayBasedMapData` does not override `equals` and `hashCode`, - // we convert it into the two `Seq`s of keys and values for correct comparisons. - def toComparable(d: Any): Any = d match { - case map: ArrayBasedMapData => - val keys = map.keyArray.array.map(toComparable).toSeq - val values = map.valueArray.array.map(toComparable).toSeq - (keys, values) - case o => o - } - assert(mapResultRow.map(toComparable) === mapExpectedRow.map(toComparable)) + val mapResultRow = convertBackToInternalRow(mapRow, fields4) + val mapExpectedRow = mapRow + checkResult(mapExpectedRow, mapResultRow, + exprDataType = StructType(fields4.zipWithIndex.map(f => StructField(s"c${f._2}", f._1))), + exprNullable = false) // UDT tests val vector = new TestUDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0))
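Note on usage (not part of the patch series): the round trip these patches enable is the one exercised by convertBackToInternalRow in UnsafeRowConverterSuite above. The sketch below is a minimal illustration using only entry points the patches themselves add or touch (UnsafeProjection.create, SafeProjection.create, InternalRow.fromSeq); treat it as a sketch of the intended API shape, not authoritative documentation.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{SafeProjection, UnsafeProjection}
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    val fields: Array[DataType] = Array(IntegerType, StringType)
    val inputRow = InternalRow.fromSeq(Seq(1, UTF8String.fromString("a")))

    // Forward: convert the generic row into the binary UnsafeRow format.
    val unsafeRow = UnsafeProjection.create(fields)(inputRow)

    // Backward: SafeProjection turns the UnsafeRow back into a safe,
    // GenericInternalRow-backed representation. Because SafeProjection extends
    // CodeGeneratorWithInterpretedFallback, it uses GenerateSafeProjection when
    // codegen succeeds and InterpretedSafeProjection otherwise (the factory mode
    // can be forced through SQLConf.CODEGEN_FACTORY_MODE, as the tests do).
    val safeRow = SafeProjection.create(fields)(unsafeRow)

    assert(safeRow.toSeq(fields) === inputRow.toSeq(fields))

Exposing SafeProjection.create overloads in place of FromUnsafeProjection mirrors the existing UnsafeProjection factory surface, which is why the FromUnsafeProjection call sites across the test suites migrate one-for-one to SafeProjection.create in this series.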