From f26c8dcf9265db6894484c0eb8255cf6c4683c9e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Jun 2016 14:36:28 +0800 Subject: [PATCH 01/11] Fix bug of Python-only UDTs when MapObjects works on it. --- python/pyspark/sql/tests.py | 6 ++++++ .../spark/sql/catalyst/expressions/objects/objects.scala | 8 ++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c631ad8a4618d..3147a3a76a6a0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -558,6 +558,12 @@ def check_datatype(datatype): _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema = schema) + df.show() + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c597a2a709445..e60d78c0affac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -427,8 +427,12 @@ case class MapObjects private( case _ => "" } + val inputDataType = inputData.dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => inputData.dataType + } - val (getLength, getLoopVar) = inputData.dataType match { + val (getLength, getLoopVar) = inputDataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" case ObjectType(cls) if cls.isArray => @@ -442,7 +446,7 @@ case class MapObjects private( s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" } - val loopNullCheck = inputData.dataType match { + val loopNullCheck = inputDataType match { case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" // The element of primitive array will never be null. case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => From cd80f0e552328d4fc3e24131bf113de9baf36032 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Jun 2016 15:14:03 +0800 Subject: [PATCH 02/11] Fix python style. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3147a3a76a6a0..6de132716e81b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -561,7 +561,7 @@ def check_datatype(datatype): schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) df = self.spark.createDataFrame( [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], - schema = schema) + schema=schema) df.show() def test_infer_schema_with_udt(self): From d22dca85f53d60c3b15c37c45d3e85171beeea85 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Jun 2016 15:38:05 +0800 Subject: [PATCH 03/11] Fix the bug of Python-only UDTs when it is used as the element of ArrayType. 
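Editor's note on the mechanism these first patches address: a Python-only UDT has no JVM user class, so on the JVM side it surfaces as PythonUserDefinedType and its values are physically stored using the UDT's sqlType. Any code that walks the schema to decide how to read a value (the row deserializer, the generated MapObjects loop) therefore has to unwrap the UDT to that sqlType first, which is what the `case p: PythonUserDefinedType => p.sqlType` branches in these diffs do. Below is a minimal, self-contained sketch of that unwrapping idea; the Sketch* names are made up for illustration and are not Spark's classes.

    object PythonUdtStorageSketch {
      // Simplified stand-ins for the relevant type shapes (illustrative only).
      sealed trait SketchType
      case object SketchDouble extends SketchType
      final case class SketchArray(element: SketchType) extends SketchType
      // A Python-only UDT has no JVM user class; its values are physically
      // stored using the UDT's sqlType.
      final case class SketchPythonUDT(sqlType: SketchType) extends SketchType

      // Unwrap any Python-only UDT to its storage type before deciding how to
      // read a value -- the same move as the
      // `case p: PythonUserDefinedType => p.sqlType` branches in these patches.
      def storageType(t: SketchType): SketchType = t match {
        case SketchPythonUDT(sql) => storageType(sql)
        case SketchArray(el)      => SketchArray(storageType(el))
        case other                => other
      }

      def main(args: Array[String]): Unit = {
        // Roughly ArrayType(PythonOnlyUDT()) from the tests: an array whose
        // element type is a Python-only UDT stored as array<double>.
        val arrayOfPyUdt = SketchArray(SketchPythonUDT(SketchArray(SketchDouble)))
        println(storageType(arrayOfPyUdt))  // SketchArray(SketchArray(SketchDouble))
      }
    }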
--- python/pyspark/sql/tests.py | 6 ++++++ .../apache/spark/sql/catalyst/encoders/RowEncoder.scala | 8 +++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6de132716e81b..a51570118fed4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -564,6 +564,12 @@ def check_datatype(datatype): schema=schema) df.show() + schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], + schema=schema) + df.show() + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 67fca153b551a..652dfb652cae6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -220,9 +220,15 @@ object RowEncoder { CreateExternalRow(fields, schema) } - private def deserializerFor(input: Expression): Expression = input.dataType match { + private def deserializerFor(input: Expression): Expression = { + deserializerFor(input, input.dataType) + } + + private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match { case dt if ScalaReflection.isNativeType(dt) => input + case p: PythonUserDefinedType => deserializerFor(input, p.sqlType) + case udt: UserDefinedType[_] => val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) val udtClass: Class[_] = if (annotation != null) { From fc9c1069d86b1f0907ff29429f3448d1d5356ffa Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 21 Jun 2016 15:42:49 +0800 Subject: [PATCH 04/11] Fix another issue. --- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 652dfb652cae6..2a6fcd03a26b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -206,6 +206,7 @@ object RowEncoder { case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) + case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) case udt: UserDefinedType[_] => ObjectType(udt.userClass) } From d603cc2000fb48d8db5a18445766795ebda4429c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 22 Jun 2016 10:51:13 +0800 Subject: [PATCH 05/11] Create new unit tests. 
--- python/pyspark/sql/tests.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a51570118fed4..251828883a371 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -558,17 +558,40 @@ def check_datatype(datatype): _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + def test_simple_udt_in_df(self): schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) df = self.spark.createDataFrame( [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema) df.show() + def test_nested_udt_in_df(self): schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) df = self.spark.createDataFrame( - [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], + [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], + schema=schema) + df.collect() + + schema = StructType().add("key", LongType()).add("val", + MapType(LongType(), PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], + schema=schema) + df.collect() + + def test_complex_nested_udt_in_df(self): + from pyspark.sql.functions import udf + + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema) - df.show() + df.collect() + + gd = df.groupby("key").agg({"val": "collect_list"}) + gd.collect() + udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) + gd.select(udf(*gd)).collect() def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT From a0b81ba448018200eb8947bc496d8f07d87a64b8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 22 Jun 2016 11:22:14 +0800 Subject: [PATCH 06/11] Fix python style. --- python/pyspark/sql/tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 251828883a371..ae5b1e7330d68 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -568,15 +568,15 @@ def test_simple_udt_in_df(self): def test_nested_udt_in_df(self): schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) df = self.spark.createDataFrame( - [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], - schema=schema) + [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], + schema=schema) df.collect() schema = StructType().add("key", LongType()).add("val", MapType(LongType(), PythonOnlyUDT())) df = self.spark.createDataFrame( - [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], - schema=schema) + [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], + schema=schema) df.collect() def test_complex_nested_udt_in_df(self): From 4c00bb1843cb9ad1ad6977188fd3f6e72f429730 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 28 Jun 2016 23:13:35 +0800 Subject: [PATCH 07/11] Avoid exposing Python udt to MapObjects. 
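Editor's note on this patch's approach: MapObjects itself stays unaware of Python UDTs. The deserializer in RowEncoder has already unwrapped the UDT to its sqlType at this point, so it passes that type along through a new optional `inputDataType` argument, and MapObjects falls back to `inputData.dataType` whenever no override is given. A tiny sketch of that optional-override-with-fallback shape, using illustrative names rather than Spark's API:

    object OptionalInputTypeSketch {
      final case class TypedInput(declaredType: String)

      // The element-mapping helper accepts an optional, explicitly supplied
      // input type and only falls back to the value's own declared type when
      // none is given -- mirrors `inputDataType.getOrElse(inputData.dataType)`
      // in the diff below.
      def elementLoop(input: TypedInput, inputType: Option[String] = None): String = {
        val effective = inputType.getOrElse(input.declaredType)
        s"generate element loop for layout: $effective"
      }

      def main(args: Array[String]): Unit = {
        // A caller that knows the Python UDT's underlying sqlType supplies it
        // explicitly, so the loop never sees the UDT wrapper:
        println(elementLoop(TypedInput("pythonUDT"), inputType = Some("array<double>")))
        // Every other caller is unchanged:
        println(elementLoop(TypedInput("array<double>")))
      }
    }

The later patches in the series revisit this trade-off: passing the type explicitly keeps UDT handling out of MapObjects, at the cost of a wider constructor.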
--- .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../catalyst/expressions/objects/objects.scala | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 2a6fcd03a26b0..c77a181ecaf20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -269,7 +269,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(deserializerFor(_), input, et), + MapObjects(deserializerFor(_), input, et, Some(dataType)), "array", ObjectType(classOf[Array[_]])) StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e60d78c0affac..f5510810ba7f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -349,11 +349,12 @@ object MapObjects { def apply( function: Expression => Expression, inputData: Expression, - elementType: DataType): MapObjects = { + elementType: DataType, + inputDataType: Option[DataType] = None): MapObjects = { val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopVar, function(loopVar), inputData) + MapObjects(loopVar, function(loopVar), inputData, inputDataType) } } @@ -370,11 +371,13 @@ object MapObjects { * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. + * @param inputDataType The dataType of inputData. Optional. */ case class MapObjects private( loopVar: LambdaVariable, lambdaFunction: Expression, - inputData: Expression) extends Expression with NonSQLExpression { + inputData: Expression, + inputDataType: Option[DataType]) extends Expression with NonSQLExpression { override def nullable: Boolean = true @@ -427,12 +430,9 @@ case class MapObjects private( case _ => "" } - val inputDataType = inputData.dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => inputData.dataType - } + val inputDT = inputDataType.getOrElse(inputData.dataType) - val (getLength, getLoopVar) = inputDataType match { + val (getLength, getLoopVar) = inputDT match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" case ObjectType(cls) if cls.isArray => @@ -446,7 +446,7 @@ case class MapObjects private( s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" } - val loopNullCheck = inputDataType match { + val loopNullCheck = inputDT match { case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" // The element of primitive array will never be null. 
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => From 65a33b05eaeef8454f8746313075163e21f73c8f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 29 Jun 2016 11:09:11 +0800 Subject: [PATCH 08/11] Address comments. --- .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../expressions/objects/objects.scala | 35 +++++++++++++++---- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index c77a181ecaf20..ec154a6d85d55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -269,7 +269,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(deserializerFor(_), input, et, Some(dataType)), + MapObjects(deserializerFor(_), input, et, dataType), "array", ObjectType(classOf[Array[_]])) StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 07def1f2a33f4..6104dcb438e82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -346,11 +346,34 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext object MapObjects { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * Construct an instance of MapObjects case class. + * + * @param function The function applied on the collection elements. + * @param inputData An expression that when evaluated returns a collection object. + * @param elementType The data type of elements in the collection. + */ + def apply( + function: Expression => Expression, + inputData: Expression, + elementType: DataType): MapObjects = { + apply(function, inputData, elementType, inputData.dataType) + } + + /** + * Construct an instance of MapObjects case class. + * + * @param function The function applied on the collection elements. + * @param inputData An expression that when evaluated returns a collection object. + * @param elementType The data type of elements in the collection. + * @param inputDataType The explicitly given data type of inputData to override the + * data type inferred from inputData (i.e., inputData.dataType). + */ def apply( function: Expression => Expression, inputData: Expression, elementType: DataType, - inputDataType: Option[DataType] = None): MapObjects = { + inputDataType: DataType): MapObjects = { val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) @@ -375,7 +398,7 @@ object MapObjects { * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. - * @param inputDataType The dataType of inputData. Optional. + * @param inputDataType The dataType of inputData. 
*/ case class MapObjects private( loopValue: String, @@ -383,7 +406,7 @@ case class MapObjects private( loopVarDataType: DataType, lambdaFunction: Expression, inputData: Expression, - inputDataType: Option[DataType]) extends Expression with NonSQLExpression { + inputDataType: DataType) extends Expression with NonSQLExpression { override def nullable: Boolean = true @@ -436,9 +459,7 @@ case class MapObjects private( case _ => "" } - val inputDT = inputDataType.getOrElse(inputData.dataType) - - val (getLength, getLoopVar) = inputDT match { + val (getLength, getLoopVar) = inputDataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" case ObjectType(cls) if cls.isArray => @@ -452,7 +473,7 @@ case class MapObjects private( s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" } - val loopNullCheck = inputDT match { + val loopNullCheck = inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" // The element of primitive array will never be null. case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => From 1b751affe091fb0c946cb5f97c75ae4c1f4e637e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 6 Jul 2016 17:18:52 +0800 Subject: [PATCH 09/11] Fix test. --- .../expressions/objects/objects.scala | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6104dcb438e82..2e1ff9a85cdb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -357,7 +357,10 @@ object MapObjects { function: Expression => Expression, inputData: Expression, elementType: DataType): MapObjects = { - apply(function, inputData, elementType, inputData.dataType) + val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() + val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, None) } /** @@ -368,6 +371,11 @@ object MapObjects { * @param elementType The data type of elements in the collection. * @param inputDataType The explicitly given data type of inputData to override the * data type inferred from inputData (i.e., inputData.dataType). + * When Python UDT whose sqlType is an array, the deserializer + * expression will apply MapObjects on it. However, as the data type + * of inputData is Python UDT, which is not an expected array type + * in MapObjects. In this case, we need to explicitly use + * Python UDT's sqlType as data type. 
*/ def apply( function: Expression => Expression, @@ -377,7 +385,8 @@ object MapObjects { val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, inputDataType) + MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, + Some(inputDataType)) } } @@ -406,7 +415,7 @@ case class MapObjects private( loopVarDataType: DataType, lambdaFunction: Expression, inputData: Expression, - inputDataType: DataType) extends Expression with NonSQLExpression { + inputDataType: Option[DataType]) extends Expression with NonSQLExpression { override def nullable: Boolean = true @@ -459,7 +468,7 @@ case class MapObjects private( case _ => "" } - val (getLength, getLoopVar) = inputDataType match { + val (getLength, getLoopVar) = inputDataType.getOrElse(inputData.dataType) match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" case ObjectType(cls) if cls.isArray => @@ -473,7 +482,7 @@ case class MapObjects private( s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" } - val loopNullCheck = inputDataType match { + val loopNullCheck = inputDataType.getOrElse(inputData.dataType) match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" // The element of primitive array will never be null. case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => From 87a0953ec36d6beacb4665a94da834d0a4615baa Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 7 Jul 2016 15:38:13 +0800 Subject: [PATCH 10/11] Address comment. --- .../spark/sql/catalyst/expressions/objects/objects.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2e1ff9a85cdb1..ec7f3b62fd01d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -407,7 +407,8 @@ object MapObjects { * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. - * @param inputDataType The dataType of inputData. + * @param inputDataType The optional dataType of inputData. If it is None, the default behavior is + * to use the resolved data type of the inputData. */ case class MapObjects private( loopValue: String, From 6065364da697cd29f9b31179063e6cf604aa25ef Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Jul 2016 14:48:04 +0800 Subject: [PATCH 11/11] Address comments. 
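Editor's note on the final revision below: the extra constructor argument goes away again, the RowEncoder call site returns to the plain three-argument `MapObjects(deserializerFor(_), input, et)`, and the Python-UDT special case is resolved inside MapObjects' code generation by matching on `inputData.dataType` and substituting the UDT's sqlType before the loop accessors are chosen. A small sketch of that resolve-it-locally pattern, with illustrative names rather than Spark's classes:

    object LocalUnwrapSketch {
      sealed trait Layout
      case object SeqLayout extends Layout
      case object ArrayLayout extends Layout
      final case class PyUDTLayout(storedAs: Layout) extends Layout

      def chooseAccessors(declared: Layout): String = {
        // Mirrors the `val inputDataType = inputData.dataType match { ... }`
        // added below: data declared as a Python UDT is actually laid out as
        // the UDT's sqlType.
        val effective = declared match {
          case PyUDTLayout(underlying) => underlying
          case other                   => other
        }
        effective match {
          case ArrayLayout => "loop with isNullAt(i)-style array accessors"
          case SeqLayout   => "loop with size()/apply(i) accessors"
          case other       => s"unexpected layout: $other"
        }
      }

      def main(args: Array[String]): Unit = {
        println(chooseAccessors(PyUDTLayout(ArrayLayout)))  // array accessors, not the UDT
        println(chooseAccessors(SeqLayout))
      }
    }

Either way the generated loop only ever sees the physical layout; what changed over the series is where that substitution happens.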
--- .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../expressions/objects/objects.scala | 44 +++++-------------- 2 files changed, 12 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index ec154a6d85d55..2a6fcd03a26b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -269,7 +269,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(deserializerFor(_), input, et, dataType), + MapObjects(deserializerFor(_), input, et), "array", ObjectType(classOf[Array[_]])) StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ec7f3b62fd01d..9621db1d38762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -360,33 +360,7 @@ object MapObjects { val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, None) - } - - /** - * Construct an instance of MapObjects case class. - * - * @param function The function applied on the collection elements. - * @param inputData An expression that when evaluated returns a collection object. - * @param elementType The data type of elements in the collection. - * @param inputDataType The explicitly given data type of inputData to override the - * data type inferred from inputData (i.e., inputData.dataType). - * When Python UDT whose sqlType is an array, the deserializer - * expression will apply MapObjects on it. However, as the data type - * of inputData is Python UDT, which is not an expected array type - * in MapObjects. In this case, we need to explicitly use - * Python UDT's sqlType as data type. - */ - def apply( - function: Expression => Expression, - inputData: Expression, - elementType: DataType, - inputDataType: DataType): MapObjects = { - val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() - val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() - val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, - Some(inputDataType)) + MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData) } } @@ -407,16 +381,13 @@ object MapObjects { * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. - * @param inputDataType The optional dataType of inputData. If it is None, the default behavior is - * to use the resolved data type of the inputData. 
*/ case class MapObjects private( loopValue: String, loopIsNull: String, loopVarDataType: DataType, lambdaFunction: Expression, - inputData: Expression, - inputDataType: Option[DataType]) extends Expression with NonSQLExpression { + inputData: Expression) extends Expression with NonSQLExpression { override def nullable: Boolean = true @@ -469,7 +440,14 @@ case class MapObjects private( case _ => "" } - val (getLength, getLoopVar) = inputDataType.getOrElse(inputData.dataType) match { + // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + val inputDataType = inputData.dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => inputData.dataType + } + + val (getLength, getLoopVar) = inputDataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" case ObjectType(cls) if cls.isArray => @@ -483,7 +461,7 @@ case class MapObjects private( s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" } - val loopNullCheck = inputDataType.getOrElse(inputData.dataType) match { + val loopNullCheck = inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" // The element of primitive array will never be null. case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>