diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 589e215c55e4..fbf0bd68b958 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -302,7 +302,7 @@ case class ExpressionEncoder[T](
   private lazy val inputRow = new GenericInternalRow(1)
 
   @transient
-  private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
+  private lazy val constructProjection = SafeProjection.create(deserializer :: Nil)
 
   /**
    * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
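With this hunk the encoder's `constructProjection` is built through `SafeProjection.create`, which (see the Projection.scala change below) extends `CodeGeneratorWithInterpretedFallback`, so row deserialization can fall back to an interpreted projection instead of failing when codegen is unavailable. A minimal sketch of what that enables; the factory-mode values named in the comment are assumptions based on the suites touched below, not part of this patch:

    // Hypothetical usage sketch (not part of the patch): the same call site can be served by
    // either a generated or an interpreted projection, driven by spark.sql.codegen.factoryMode
    // (assumed values: FALLBACK, CODEGEN_ONLY, NO_CODEGEN).
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, SafeProjection}
    import org.apache.spark.sql.types.IntegerType

    val ref = BoundReference(0, IntegerType, nullable = true)
    val proj = SafeProjection.create(ref :: Nil)
    assert(proj(InternalRow(1)).getInt(0) == 1)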
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala
new file mode 100644
index 000000000000..70789dac1d87
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.types._
+
+
+/**
+ * An interpreted version of a safe projection.
+ *
+ * @param expressions that produce the resulting fields. These expressions must be bound
+ *                    to a schema.
+ */
+class InterpretedSafeProjection(expressions: Seq[Expression]) extends Projection {
+
+  private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType))
+
+  private[this] val exprsWithWriters = expressions.zipWithIndex.filter {
+    case (NoOp, _) => false
+    case _ => true
+  }.map { case (e, i) =>
+    val converter = generateSafeValueConverter(e.dataType)
+    val writer = InternalRow.getWriter(i, e.dataType)
+    val f = if (!e.nullable) {
+      (v: Any) => writer(mutableRow, converter(v))
+    } else {
+      (v: Any) => {
+        if (v == null) {
+          mutableRow.setNullAt(i)
+        } else {
+          writer(mutableRow, converter(v))
+        }
+      }
+    }
+    (e, f)
+  }
+
+  private def generateSafeValueConverter(dt: DataType): Any => Any = dt match {
+    case ArrayType(elemType, _) =>
+      val elementConverter = generateSafeValueConverter(elemType)
+      v => {
+        val arrayValue = v.asInstanceOf[ArrayData]
+        val result = new Array[Any](arrayValue.numElements())
+        arrayValue.foreach(elemType, (i, e) => {
+          result(i) = elementConverter(e)
+        })
+        new GenericArrayData(result)
+      }
+
+    case st: StructType =>
+      val fieldTypes = st.fields.map(_.dataType)
+      val fieldConverters = fieldTypes.map(generateSafeValueConverter)
+      v => {
+        val row = v.asInstanceOf[InternalRow]
+        val ar = new Array[Any](row.numFields)
+        var idx = 0
+        while (idx < row.numFields) {
+          ar(idx) = fieldConverters(idx)(row.get(idx, fieldTypes(idx)))
+          idx += 1
+        }
+        new GenericInternalRow(ar)
+      }
+
+    case MapType(keyType, valueType, _) =>
+      lazy val keyConverter = generateSafeValueConverter(keyType)
+      lazy val valueConverter = generateSafeValueConverter(valueType)
+      v => {
+        val mapValue = v.asInstanceOf[MapData]
+        val keys = mapValue.keyArray().toArray[Any](keyType)
+        val values = mapValue.valueArray().toArray[Any](valueType)
+        val convertedKeys = keys.map(keyConverter)
+        val convertedValues = values.map(valueConverter)
+        ArrayBasedMapData(convertedKeys, convertedValues)
+      }
+
+    case udt: UserDefinedType[_] =>
+      generateSafeValueConverter(udt.sqlType)
+
+    case _ => identity
+  }
+
+  override def apply(row: InternalRow): InternalRow = {
+    var i = 0
+    while (i < exprsWithWriters.length) {
+      val (expr, writer) = exprsWithWriters(i)
+      writer(expr.eval(row))
+      i += 1
+    }
+    mutableRow
+  }
+}
+
+/**
+ * Helper functions for creating an [[InterpretedSafeProjection]].
+ */
+object InterpretedSafeProjection {
+
+  /**
+   * Returns a [[SafeProjection]] for given sequence of bound Expressions.
+   */
+  def createProjection(exprs: Seq[Expression]): Projection = {
+    // We need to make sure that we do not reuse stateful expressions.
+    val cleanedExpressions = exprs.map(_.transform {
+      case s: Stateful => s.freshCopy()
+    })
+    new InterpretedSafeProjection(cleanedExpressions)
+  }
+}
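The converter above recursively rewrites nested unsafe values (arrays, structs, maps, and the SQL representation of UDTs) into generic ones before writing them into a `SpecificInternalRow`. A small usage sketch of the interpreted path; the data is illustrative and not taken from this patch:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, InterpretedSafeProjection, UnsafeProjection}
    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType}

    val fields = Array[DataType](ArrayType(IntegerType))
    // Build an UnsafeRow holding an array column, then convert it back without any codegen.
    val unsafeRow = UnsafeProjection.create(fields)(InternalRow(new GenericArrayData(Array[Any](1, 2, 3))))
    val exprs = fields.toSeq.zipWithIndex.map { case (dt, i) => BoundReference(i, dt, nullable = true) }
    val safeRow = InterpretedSafeProjection.createProjection(exprs)(unsafeRow)
    assert(safeRow.getArray(0).toIntArray().toSeq == Seq(1, 2, 3))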
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 792646cf9f10..b48f7ba655b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -169,26 +169,40 @@ object UnsafeProjection
 /**
  * A projection that could turn UnsafeRow into GenericInternalRow
  */
-object FromUnsafeProjection {
+object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expression], Projection] {
+
+  override protected def createCodeGeneratedObject(in: Seq[Expression]): Projection = {
+    GenerateSafeProjection.generate(in)
+  }
+
+  override protected def createInterpretedObject(in: Seq[Expression]): Projection = {
+    InterpretedSafeProjection.createProjection(in)
+  }
 
   /**
-   * Returns a Projection for given StructType.
+   * Returns a SafeProjection for given StructType.
    */
-  def apply(schema: StructType): Projection = {
-    apply(schema.fields.map(_.dataType))
+  def create(schema: StructType): Projection = create(schema.fields.map(_.dataType))
+
+  /**
+   * Returns a SafeProjection for given Array of DataTypes.
+   */
+  def create(fields: Array[DataType]): Projection = {
+    createObject(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
   }
 
   /**
-   * Returns an UnsafeProjection for given Array of DataTypes.
+   * Returns a SafeProjection for given sequence of Expressions (bounded).
    */
-  def apply(fields: Seq[DataType]): Projection = {
-    create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+  def create(exprs: Seq[Expression]): Projection = {
+    createObject(exprs)
   }
 
   /**
-   * Returns a Projection for given sequence of Expressions (bounded).
+   * Returns a SafeProjection for given sequence of Expressions, which will be bound to
+   * `inputSchema`.
    */
-  private def create(exprs: Seq[Expression]): Projection = {
-    GenerateSafeProjection.generate(exprs)
+  def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+    create(toBoundExprs(exprs, inputSchema))
   }
 }
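Call sites that previously used `FromUnsafeProjection(schema)` now go through the `create` overloads above (StructType, Array[DataType], bound expressions, or unbound expressions plus an input schema). A hedged round-trip sketch; the schema and values are illustrative only:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{SafeProjection, UnsafeProjection}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    val schema = new StructType().add("id", IntegerType).add("name", StringType)
    val unsafeRow = UnsafeProjection.create(schema)(InternalRow(1, UTF8String.fromString("a")))
    // Was: FromUnsafeProjection(schema).
    val safeRow = SafeProjection.create(schema)(unsafeRow)
    assert(safeRow.getInt(0) == 1 && safeRow.getUTF8String(1).toString == "a")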
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 7843003a4aac..7e6fe5b4e206 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -251,7 +251,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       UTF8String.fromString("c"))
     assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3)
 
-    val fromUnsafe = FromUnsafeProjection(schema)
+    val fromUnsafe = SafeProjection.create(schema)
     val internalRow2 = fromUnsafe(unsafeRow)
     assert(internalRow === internalRow2)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala
index 6ea3b05ff9c1..da5bddb0c09f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala
@@ -106,4 +106,19 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT
       assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected)
     }
   }
+
+  test("SPARK-25374 Correctly handles NoOp in SafeProjection") {
+    val exprs = Seq(Add(BoundReference(0, IntegerType, nullable = true), Literal.create(1)), NoOp)
+    val input = InternalRow.fromSeq(1 :: 1 :: Nil)
+    val expected = 2 :: null :: Nil
+    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
+      val proj = SafeProjection.createObject(exprs)
+      assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected)
+    }
+
+    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) {
+      val proj = SafeProjection.createObject(exprs)
+      assert(proj(input).toSeq(StructType.fromDDL("c0 int, c1 int")) === expected)
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index a7282e1b1cad..b4fd170467d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -321,8 +321,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
     plan.initialize(0)
-    actual = FromUnsafeProjection(expression.dataType :: Nil)(
-      plan(inputRow)).get(0, expression.dataType)
+    val ref = new BoundReference(0, expression.dataType, nullable = true)
+    actual = GenerateSafeProjection.generate(ref :: Nil)(plan(inputRow)).get(0, expression.dataType)
     assert(checkResult(actual, expected, expression))
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
index 2db1c3b98819..0d594eb10962 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
@@ -51,7 +51,7 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
     val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, fixedLengthTypes.length)
     val proj = createMutableProjection(fixedLengthTypes)
     val projUnsafeRow = proj.target(unsafeBuffer)(inputRow)
-    assert(FromUnsafeProjection.apply(fixedLengthTypes)(projUnsafeRow) === inputRow)
+    assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow)
   }
 
   testBothCodegenAndInterpreted("variable-length types") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 268372b5d050..ecb8047459b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types.{IntegerType, LongType, _}
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase
   with ExpressionEvalHelper {
@@ -535,4 +535,91 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
     assert(unsafeRow.getSizeInBytes ==
       8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
   }
+
+  testBothCodegenAndInterpreted("SPARK-25374 converts back into safe representation") {
+    def convertBackToInternalRow(inputRow: InternalRow, fields: Array[DataType]): InternalRow = {
+      val unsafeProj = UnsafeProjection.create(fields)
+      val unsafeRow = unsafeProj(inputRow)
+      val safeProj = SafeProjection.create(fields)
+      safeProj(unsafeRow)
+    }
+
+    // Simple tests
+    val inputRow = InternalRow.fromSeq(Seq(
+      false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0, UTF8String.fromString("test"),
+      Decimal(255), CalendarInterval.fromString("interval 1 day"), Array[Byte](1, 2)
+    ))
+    val fields1 = Array(
+      BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
+      DoubleType, StringType, DecimalType.defaultConcreteType, CalendarIntervalType,
+      BinaryType)
+
+    assert(convertBackToInternalRow(inputRow, fields1) === inputRow)
+
+    // Array tests
+    val arrayRow = InternalRow.fromSeq(Seq(
+      createArray(1, 2, 3),
+      createArray(
+        createArray(Seq("a", "b", "c").map(UTF8String.fromString): _*),
+        createArray(Seq("d").map(UTF8String.fromString): _*))
+    ))
+    val fields2 = Array[DataType](
+      ArrayType(IntegerType),
+      ArrayType(ArrayType(StringType)))
+
+    assert(convertBackToInternalRow(arrayRow, fields2) === arrayRow)
+
+    // Struct tests
+    val structRow = InternalRow.fromSeq(Seq(
+      InternalRow.fromSeq(Seq[Any](1, 4.0)),
+      InternalRow.fromSeq(Seq(
+        UTF8String.fromString("test"),
+        InternalRow.fromSeq(Seq(
+          1,
+          createArray(Seq("2", "3").map(UTF8String.fromString): _*)
+        ))
+      ))
+    ))
+    val fields3 = Array[DataType](
+      StructType(
+        StructField("c0", IntegerType) ::
+        StructField("c1", DoubleType) ::
+        Nil),
+      StructType(
+        StructField("c2", StringType) ::
+        StructField("c3", StructType(
+          StructField("c4", IntegerType) ::
+          StructField("c5", ArrayType(StringType)) ::
+          Nil)) ::
+        Nil))
+
+    assert(convertBackToInternalRow(structRow, fields3) === structRow)
+
+    // Map tests
+    val mapRow = InternalRow.fromSeq(Seq(
+      createMap(Seq("k1", "k2").map(UTF8String.fromString): _*)(1, 2),
+      createMap(
+        createMap(3, 5)(Seq("v1", "v2").map(UTF8String.fromString): _*),
+        createMap(7, 9)(Seq("v3", "v4").map(UTF8String.fromString): _*)
+      )(
+        createMap(Seq("k3", "k4").map(UTF8String.fromString): _*)(3.toShort, 4.toShort),
+        createMap(Seq("k5", "k6").map(UTF8String.fromString): _*)(5.toShort, 6.toShort)
+      )))
+    val fields4 = Array[DataType](
+      MapType(StringType, IntegerType),
+      MapType(MapType(IntegerType, StringType), MapType(StringType, ShortType)))
+
+    val mapResultRow = convertBackToInternalRow(mapRow, fields4)
+    val mapExpectedRow = mapRow
+    checkResult(mapExpectedRow, mapResultRow,
+      exprDataType = StructType(fields4.zipWithIndex.map(f => StructField(s"c${f._2}", f._1))),
+      exprNullable = false)
+
+    // UDT tests
+    val vector = new TestUDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0))
+    val udt = new TestUDT.MyDenseVectorUDT()
+    val udtRow = InternalRow.fromSeq(Seq(udt.serialize(vector)))
+    val fields5 = Array[DataType](udt)
+    assert(convertBackToInternalRow(udtRow, fields5) === udtRow)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala
index 614f24db0aaf..b0f55b3b5c44 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala
@@ -17,25 +17,24 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
+import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow, SafeProjection}
 
 /**
  * Evaluator for a [[DeclarativeAggregate]].
 */
 case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) {
 
-  lazy val initializer = GenerateSafeProjection.generate(function.initialValues)
+  lazy val initializer = SafeProjection.create(function.initialValues)
 
-  lazy val updater = GenerateSafeProjection.generate(
+  lazy val updater = SafeProjection.create(
     function.updateExpressions,
     function.aggBufferAttributes ++ input)
 
-  lazy val merger = GenerateSafeProjection.generate(
+  lazy val merger = SafeProjection.create(
     function.mergeExpressions,
     function.aggBufferAttributes ++ function.inputAggBufferAttributes)
 
-  lazy val evaluator = GenerateSafeProjection.generate(
+  lazy val evaluator = SafeProjection.create(
     function.evaluateExpression :: Nil,
     function.aggBufferAttributes)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
index 2c45b3b0c73d..4c9bcfe8f93a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -58,7 +58,7 @@ class GeneratedProjectionSuite extends SparkFunSuite {
     }
 
     // test generated SafeProjection
-    val safeProj = FromUnsafeProjection(nestedSchema)
+    val safeProj = SafeProjection.create(nestedSchema)
     val result = safeProj(unsafe)
     // Can't compare GenericInternalRow with JoinedRow directly
     (0 until N).foreach { i =>
@@ -109,7 +109,7 @@ class GeneratedProjectionSuite extends SparkFunSuite {
     }
 
     // test generated SafeProjection
-    val safeProj = FromUnsafeProjection(nestedSchema)
+    val safeProj = SafeProjection.create(nestedSchema)
     val result = safeProj(unsafe)
     // Can't compare GenericInternalRow with JoinedRow directly
     (0 until N).foreach { i =>
@@ -147,7 +147,7 @@ class GeneratedProjectionSuite extends SparkFunSuite {
     assert(unsafeRow.getArray(1).getBinary(1) === null)
     assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4)))
 
-    val safeProj = FromUnsafeProjection(fields)
+    val safeProj = SafeProjection.create(fields)
     val row2 = safeProj(unsafeRow)
     assert(row2 === row)
   }
@@ -233,7 +233,7 @@ class GeneratedProjectionSuite extends SparkFunSuite {
     val nestedSchema = StructType(
       Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)
 
-    val safeProj = FromUnsafeProjection(nestedSchema)
+    val safeProj = SafeProjection.create(nestedSchema)
     val result = safeProj(nested)
 
     // test generated MutableProjection
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
index 6400898343ae..da71e3a4d53e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
@@ -22,7 +22,7 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
 import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
-import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{SafeProjection, UnsafeArrayData, UnsafeProjection}
 import org.apache.spark.sql.types._
 
 class ArrayDataIndexedSeqSuite extends SparkFunSuite {
@@ -77,7 +77,7 @@ class ArrayDataIndexedSeqSuite extends SparkFunSuite {
       val internalRow = rowConverter.toRow(row)
 
       val unsafeRowConverter = UnsafeProjection.create(schema)
-      val safeRowConverter = FromUnsafeProjection(schema)
+      val safeRowConverter = SafeProjection.create(schema)
 
       val unsafeRow = unsafeRowConverter(internalRow)
       val safeRow = safeRowConverter(unsafeRow)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala
new file mode 100644
index 000000000000..1be8ee9dfa92
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/TestUDT.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+
+
+// Wrapped in an object to check Scala compatibility. See SPARK-13929
+object TestUDT {
+
+  @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
+  private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
+    override def hashCode(): Int = java.util.Arrays.hashCode(data)
+
+    override def equals(other: Any): Boolean = other match {
+      case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data)
+      case _ => false
+    }
+
+    override def toString: String = data.mkString("(", ", ", ")")
+  }
+
+  private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
+
+    override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
+
+    override def serialize(features: MyDenseVector): ArrayData = {
+      new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
+    }
+
+    override def deserialize(datum: Any): MyDenseVector = {
+      datum match {
+        case data: ArrayData =>
+          new MyDenseVector(data.toDoubleArray())
+      }
+    }
+
+    override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
+
+    private[spark] override def asNullable: MyDenseVectorUDT = this
+
+    override def hashCode(): Int = getClass.hashCode()
+
+    override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT]
+  }
+}
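The UDT test classes move from UserDefinedTypeSuite (sql/core, see below) into this new sql/catalyst file so that catalyst-side suites such as UnsafeRowConverterSuite above can exercise UDT round-trips. A small sketch of the round trip; it is illustrative only, and because the classes are `private[sql]` it only compiles inside an `org.apache.spark.sql` test package:

    import org.apache.spark.sql.types.TestUDT

    val udt = new TestUDT.MyDenseVectorUDT()
    val vec = new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25))
    // serialize produces the ArrayType(DoubleType) catalyst value; deserialize restores the UDT.
    assert(udt.deserialize(udt.serialize(vec)) == vec)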
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 64b42c32b8b1..54299e9808bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -312,13 +312,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
         assert(msg.contains("CSV data source does not support array data type"))
 
         msg = intercept[AnalysisException] {
-          Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors")
+          Seq((1, new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors")
             .write.mode("overwrite").csv(csvDir)
         }.getMessage
         assert(msg.contains("CSV data source does not support array data type"))
 
         msg = intercept[AnalysisException] {
-          val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil)
+          val schema = StructType(StructField("a", new TestUDT.MyDenseVectorUDT(), true) :: Nil)
           spark.range(1).write.mode("overwrite").csv(csvDir)
           spark.read.schema(schema).csv(csvDir).collect()
         }.getMessage
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index cf956316057e..6628d36ffc70 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -20,56 +20,14 @@ package org.apache.spark.sql
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal}
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-private[sql] case class MyLabeledPoint(label: Double, features: UDT.MyDenseVector) {
+private[sql] case class MyLabeledPoint(label: Double, features: TestUDT.MyDenseVector) {
   def getLabel: Double = label
-  def getFeatures: UDT.MyDenseVector = features
-}
-
-// Wrapped in an object to check Scala compatibility. See SPARK-13929
-object UDT {
-
-  @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
-  private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
-    override def hashCode(): Int = java.util.Arrays.hashCode(data)
-
-    override def equals(other: Any): Boolean = other match {
-      case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data)
-      case _ => false
-    }
-
-    override def toString: String = data.mkString("(", ", ", ")")
-  }
-
-  private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
-
-    override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
-
-    override def serialize(features: MyDenseVector): ArrayData = {
-      new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
-    }
-
-    override def deserialize(datum: Any): MyDenseVector = {
-      datum match {
-        case data: ArrayData =>
-          new MyDenseVector(data.toDoubleArray())
-      }
-    }
-
-    override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
-
-    private[spark] override def asNullable: MyDenseVectorUDT = this
-
-    override def hashCode(): Int = getClass.hashCode()
-
-    override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT]
-  }
-
+  def getFeatures: TestUDT.MyDenseVector = features
 }
 
 // object and classes to test SPARK-19311
@@ -148,12 +106,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
   import testImplicits._
 
   private lazy val pointsRDD = Seq(
-    MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
-    MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))).toDF()
+    MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))),
+    MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0)))).toDF()
 
   private lazy val pointsRDD2 = Seq(
-    MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
-    MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.3, 3.0)))).toDF()
+    MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))),
+    MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.3, 3.0)))).toDF()
 
   test("register user type: MyDenseVector for MyLabeledPoint") {
     val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
@@ -162,16 +120,17 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     assert(labelsArrays.contains(1.0))
     assert(labelsArrays.contains(0.0))
 
-    val features: RDD[UDT.MyDenseVector] =
-      pointsRDD.select('features).rdd.map { case Row(v: UDT.MyDenseVector) => v }
-    val featuresArrays: Array[UDT.MyDenseVector] = features.collect()
+    val features: RDD[TestUDT.MyDenseVector] =
+      pointsRDD.select('features).rdd.map { case Row(v: TestUDT.MyDenseVector) => v }
+    val featuresArrays: Array[TestUDT.MyDenseVector] = features.collect()
     assert(featuresArrays.size === 2)
-    assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.1, 1.0))))
-    assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.2, 2.0))))
+    assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.1, 1.0))))
+    assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.2, 2.0))))
   }
 
   test("UDTs and UDFs") {
-    spark.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector])
+    spark.udf.register("testType",
+      (d: TestUDT.MyDenseVector) => d.isInstanceOf[TestUDT.MyDenseVector])
     pointsRDD.createOrReplaceTempView("points")
     checkAnswer(
       sql("SELECT testType(features) from points"),
@@ -185,8 +144,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
       checkAnswer(
         spark.read.parquet(path),
         Seq(
-          Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
-          Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
+          Row(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))),
+          Row(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0)))))
     }
   }
 
@@ -197,17 +156,17 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
       checkAnswer(
         spark.read.parquet(path),
         Seq(
-          Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
-          Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
+          Row(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))),
+          Row(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0)))))
     }
   }
 
   // Tests to make sure that all operators correctly convert types on the way out.
   test("Local UDTs") {
-    val vec = new UDT.MyDenseVector(Array(0.1, 1.0))
+    val vec = new TestUDT.MyDenseVector(Array(0.1, 1.0))
     val df = Seq((1, vec)).toDF("int", "vec")
-    assert(vec === df.collect()(0).getAs[UDT.MyDenseVector](1))
-    assert(vec === df.take(1)(0).getAs[UDT.MyDenseVector](1))
+    assert(vec === df.collect()(0).getAs[TestUDT.MyDenseVector](1))
+    assert(vec === df.take(1)(0).getAs[TestUDT.MyDenseVector](1))
     checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec))
     checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec))
   }
@@ -219,14 +178,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     )
     val schema = StructType(Seq(
       StructField("id", IntegerType, false),
-      StructField("vec", new UDT.MyDenseVectorUDT, false)
+      StructField("vec", new TestUDT.MyDenseVectorUDT, false)
     ))
 
     val jsonRDD = spark.read.schema(schema).json(data.toDS())
    checkAnswer(
      jsonRDD,
-      Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) ::
-        Row(2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75))) ::
+      Row(1, new TestUDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) ::
+        Row(2, new TestUDT.MyDenseVector(Array(2.25, 4.5, 8.75))) ::
         Nil
     )
   }
@@ -239,25 +198,25 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
 
     val schema = StructType(Seq(
       StructField("id", IntegerType, false),
-      StructField("vec", new UDT.MyDenseVectorUDT, false)
+      StructField("vec", new TestUDT.MyDenseVectorUDT, false)
     ))
 
     val jsonDataset = spark.read.schema(schema).json(data.toDS())
-      .as[(Int, UDT.MyDenseVector)]
+      .as[(Int, TestUDT.MyDenseVector)]
     checkDataset(
       jsonDataset,
-      (1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))),
-      (2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75)))
+      (1, new TestUDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))),
+      (2, new TestUDT.MyDenseVector(Array(2.25, 4.5, 8.75)))
     )
   }
 
   test("SPARK-10472 UserDefinedType.typeName") {
     assert(IntegerType.typeName === "integer")
-    assert(new UDT.MyDenseVectorUDT().typeName === "mydensevector")
+    assert(new TestUDT.MyDenseVectorUDT().typeName === "mydensevector")
   }
 
   test("Catalyst type converter null handling for UDTs") {
-    val udt = new UDT.MyDenseVectorUDT()
+    val udt = new TestUDT.MyDenseVectorUDT()
     val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt)
     assert(toScalaConverter(null) === null)
 
@@ -303,12 +262,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
   test("except on UDT") {
     checkAnswer(
       pointsRDD.except(pointsRDD2),
-      Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
+      Seq(Row(0.0, new TestUDT.MyDenseVector(Array(0.2, 2.0)))))
   }
 
   test("SPARK-23054 Cast UserDefinedType to string") {
-    val udt = new UDT.MyDenseVectorUDT()
-    val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0))
+    val udt = new TestUDT.MyDenseVectorUDT()
+    val vector = new TestUDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0))
     val data = udt.serialize(vector)
     val ret = Cast(Literal(data, udt), StringType, None)
     checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 9d23161c1f24..dff37ca2d40f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1463,7 +1463,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
       DateType, TimestampType,
       ArrayType(IntegerType), MapType(StringType, LongType), struct,
-      new UDT.MyDenseVectorUDT())
+      new TestUDT.MyDenseVectorUDT())
     val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
       StructField(s"col$index", dataType, nullable = true)
     }
@@ -1487,7 +1487,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       Seq(2, 3, 4),
       Map("a string" -> 2000L),
       Row(4.75.toFloat, Seq(false, true)),
-      new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))
+      new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25)))
     val data =
       Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
index 998b7b31dcd6..918dbcdfa1cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StructType, TestUDT}
 import org.apache.spark.util.Utils
 
 case class AllDataTypesWithNonPrimitiveType(
@@ -103,7 +103,7 @@ abstract class OrcQueryTest extends OrcTest {
 
   test("Read/write UserDefinedType") {
     withTempPath { path =>
-      val data = Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25))))
+      val data = Seq((1, new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25))))
       val udtDF = data.toDF("id", "vectors")
       udtDF.write.orc(path.getAbsolutePath)
       val readBack = spark.read.schema(udtDF.schema).orc(path.getAbsolutePath)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index c65bf7c14c7a..cfae2d82e273 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -884,7 +884,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
       DateType, TimestampType,
       ArrayType(IntegerType), MapType(StringType, LongType), struct,
-      new UDT.MyDenseVectorUDT())
+      new TestUDT.MyDenseVectorUDT())
     // Right now, we will use SortAggregate to handle UDAFs.
     // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortAggregate to use
     // UnsafeRow as the aggregation buffer. While, dataTypes will trigger
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index c9309197791b..2391106cfb25 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -124,7 +124,7 @@ class ObjectHashAggregateSuite
         .add("f2", ArrayType(BooleanType), nullable = true),
 
       // UDT
-      new UDT.MyDenseVectorUDT(),
+      new TestUDT.MyDenseVectorUDT(),
 
       // Others
       StringType,
@@ -259,7 +259,7 @@ class ObjectHashAggregateSuite
       StringType, BinaryType, NullType, BooleanType
     )
 
-    val udt = new UDT.MyDenseVectorUDT()
+    val udt = new TestUDT.MyDenseVectorUDT()
 
     val fixedLengthTypes = builtinNumericTypes ++ Seq(BooleanType, NullType)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
index 6bd59fde550d..6075f2c8877d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
@@ -115,7 +115,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
     new StructType()
       .add("f1", FloatType, nullable = true)
       .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
-    new UDT.MyDenseVectorUDT()
+    new TestUDT.MyDenseVectorUDT()
   ).filter(supportsDataType)
 
   test(s"test all data types") {