From 4482e1c2b920e201afca1379a3686df9a4db5bc9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 8 Apr 2017 02:32:48 +0900 Subject: [PATCH 1/9] initial commit --- .../spark/sql/catalyst/ScalaReflection.scala | 37 +++++++++++-------- .../sql/catalyst/encoders/RowEncoder.scala | 18 ++++----- .../expressions/objects/objects.scala | 28 +++++++++----- 3 files changed, 49 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 206ae2f0e5eb1..638d1e9f8b744 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -251,19 +251,22 @@ object ScalaReflection extends ScalaReflection { getPath :: Nil) case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) + Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger])) + Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), + returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt])) + Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), + returnNullable = false) case t if t <:< 
localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -284,7 +287,7 @@ object ScalaReflection extends ScalaReflection { val arrayCls = arrayClassFor(elementType) if (elementNullable) { - Invoke(arrayData, "array", arrayCls) + Invoke(arrayData, "array", arrayCls, returnNullable = false) } else { val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => "toIntArray" @@ -297,7 +300,7 @@ object ScalaReflection extends ScalaReflection { case other => throw new IllegalStateException("expect primitive array element type " + "but got " + other) } - Invoke(arrayData, primitiveMethod, arrayCls) + Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } case t if t <:< localTypeOf[Seq[_]] => @@ -330,19 +333,21 @@ object ScalaReflection extends ScalaReflection { Invoke( MapObjects( p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), + returnNullable = false), schemaFor(keyType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) val valueData = Invoke( MapObjects( p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), + returnNullable = false), schemaFor(valueType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) StaticInvoke( ArrayBasedMapData.getClass, @@ -356,7 +361,8 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: 
Nil, + returnNullable = false) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() @@ -365,7 +371,8 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil, + returnNullable = false) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) @@ -577,7 +584,7 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "serialize", udt, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() @@ -586,7 +593,7 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "serialize", udt, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) case t if definedByConstructorParams(t) => if (seenTypeSet.contains(t)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e95e97b9dc6cb..a2721bdbee6fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -89,7 +89,7 @@ object RowEncoder { udtClass, Nil, dataType = ObjectType(udtClass), false) - Invoke(obj, "serialize", udt, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) 
case TimestampType => StaticInvoke( @@ -136,16 +136,16 @@ object RowEncoder { case t @ MapType(kt, vt, valueNullable) => val keys = Invoke( - Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), returnNullable = false), "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedKeys = serializerFor(keys, ArrayType(kt, false)) val values = Invoke( - Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), returnNullable = false), "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) NewInstance( @@ -245,7 +245,7 @@ object RowEncoder { udtClass, Nil, dataType = ObjectType(udtClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) + Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil, returnNullable = false) case TimestampType => StaticInvoke( @@ -262,17 +262,17 @@ object RowEncoder { input :: Nil) case _: DecimalType => - Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), returnNullable = false) case StringType => - Invoke(input, "toString", ObjectType(classOf[String])) + Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) case ArrayType(et, nullable) => val arrayData = Invoke( MapObjects(deserializerFor(_), input, et), "array", - ObjectType(classOf[Array[_]])) + ObjectType(classOf[Array[_]]), returnNullable = false) StaticInvoke( scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 00e2ac91e67ca..61c3e20c591bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -225,20 +225,28 @@ case class Invoke( getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") } else { val funcResult = ctx.freshName("funcResult") - s""" - Object $funcResult = null; - ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - if ($funcResult == null) { - ${ev.isNull} = true; - } else { + if (!returnNullable) { + s""" + Object $funcResult = null; + ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; - } - """ + """ + } else { + s""" + Object $funcResult = null; + ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} + if ($funcResult == null) { + ${ev.isNull} = true; + } else { + ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + } + """ + } } // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. 
- val postNullCheck = if (ctx.defaultValue(dataType) == "null") { + val postNullCheck = if (ctx.defaultValue(dataType) == "null" && returnNullable) { s"${ev.isNull} = ${ev.value} == null;" } else { "" @@ -608,7 +616,7 @@ case class MapObjects private( $convertedArray = $arrayConstructor; """, genValue => s"$convertedArray[$loopIndex] = $genValue;", - s"new ${classOf[GenericArrayData].getName}($convertedArray);" + s"new ${classOf[GenericArrayData].getName}($convertedArray); /*###*/" ) } From ae5e232da543f6c7c5d6f6a3526bdb56c6f793b8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 8 Apr 2017 03:41:29 +0900 Subject: [PATCH 2/9] fix scala style errors --- .../apache/spark/sql/catalyst/encoders/RowEncoder.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index a2721bdbee6fa..1021f0b481f10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -136,14 +136,16 @@ object RowEncoder { case t @ MapType(kt, vt, valueNullable) => val keys = Invoke( - Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), returnNullable = false), + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedKeys = serializerFor(keys, ArrayType(kt, false)) val values = Invoke( - Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), returnNullable = false), + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = 
false) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) @@ -262,7 +264,8 @@ object RowEncoder { input :: Nil) case _: DecimalType => - Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), returnNullable = false) + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case StringType => Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) From fc6caacf5fca8cd89b1e324540761ae23f88d9d1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 8 Apr 2017 16:22:33 +0900 Subject: [PATCH 3/9] address review comments --- .../expressions/objects/objects.scala | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 61c3e20c591bb..ec7c984f7ffbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -225,23 +225,11 @@ case class Invoke( getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") } else { val funcResult = ctx.freshName("funcResult") - if (!returnNullable) { - s""" - Object $funcResult = null; - ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; - """ - } else { - s""" - Object $funcResult = null; - ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - if ($funcResult == null) { - ${ev.isNull} = true; - } else { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; - } - """ - } + s""" + Object $funcResult = null; + ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} + ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + """ } // If the 
function can return null, we do an extra check to make sure our null bit is still set @@ -616,7 +604,7 @@ case class MapObjects private( $convertedArray = $arrayConstructor; """, genValue => s"$convertedArray[$loopIndex] = $genValue;", - s"new ${classOf[GenericArrayData].getName}($convertedArray); /*###*/" + s"new ${classOf[GenericArrayData].getName}($convertedArray);" ) } From 41c96ab5c043285fe88060ed57cbd56501bc63b6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 8 Apr 2017 17:56:40 +0900 Subject: [PATCH 4/9] address review comments --- .../spark/sql/catalyst/ScalaReflection.scala | 10 ++++------ .../catalyst/expressions/objects/objects.scala | 17 ++++++++--------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 638d1e9f8b744..198122759e4ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -361,8 +361,7 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil, - returnNullable = false) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() @@ -371,8 +370,7 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil, - returnNullable = false) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) case t if definedByConstructorParams(t) => 
val params = getConstructorParameters(t) @@ -584,7 +582,7 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) + Invoke(obj, "serialize", udt, inputObject :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() @@ -593,7 +591,7 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) + Invoke(obj, "serialize", udt, inputObject :: Nil) case t if definedByConstructorParams(t) => if (seenTypeSet.contains(t)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ec7c984f7ffbe..efa73ba382507 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -225,21 +225,21 @@ case class Invoke( getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") } else { val funcResult = ctx.freshName("funcResult") + // If the function can return null, we do an extra check to make sure our null bit is still + // set correctly. + val postNullCheck = if (returnNullable) { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } s""" Object $funcResult = null; ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + $postNullCheck """ } - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. 
- val postNullCheck = if (ctx.defaultValue(dataType) == "null" && returnNullable) { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } - val code = s""" ${obj.code} boolean ${ev.isNull} = true; @@ -250,7 +250,6 @@ case class Invoke( if (!${ev.isNull}) { $evaluate } - $postNullCheck } """ ev.copy(code = code) From a39803ab0f77124add833bebb3cb0353306aa1f2 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 8 Apr 2017 17:57:07 +0900 Subject: [PATCH 5/9] add test suites --- .../org/apache/spark/sql/DatasetPrimitiveSuite.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 82b707537e45f..541565344f758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -96,6 +96,16 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(dsBoolean.map(e => !e), false, true) } + test("mapPrimitiveArray") { + val dsInt = Seq(Array(1, 2), Array(3, 4)).toDS() + checkDataset(dsInt.map(e => e), Array(1, 2), Array(3, 4)) + checkDataset(dsInt.map(e => null: Array[Int]), null, null) + + val dsDouble = Seq(Array(1D, 2D), Array(3D, 4D)).toDS() + checkDataset(dsDouble.map(e => e), Array(1D, 2D), Array(3D, 4D)) + checkDataset(dsDouble.map(e => null: Array[Double]), null, null) + } + test("filter") { val ds = Seq(1, 2, 3, 4).toDS() checkDataset( From 895a3be09d674bf9c57969ca6ad08129df544be8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 8 Apr 2017 19:47:52 +0900 Subject: [PATCH 6/9] address review comment --- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 1021f0b481f10..0f8282d3b2f1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -247,7 +247,7 @@ object RowEncoder { udtClass, Nil, dataType = ObjectType(udtClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil, returnNullable = false) + Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) case TimestampType => StaticInvoke( From 3080ac2230e2512d6de3f6aadfed0e31b3b7eed3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 8 Apr 2017 19:48:28 +0900 Subject: [PATCH 7/9] fix test failure --- .../sql/catalyst/expressions/objects/objects.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index efa73ba382507..d86e9f13ce7b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -227,15 +227,20 @@ case class Invoke( val funcResult = ctx.freshName("funcResult") // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. 
- val postNullCheck = if (returnNullable) { - s"${ev.isNull} = ${ev.value} == null;" + val postNullCheck = if (!returnNullable) { + s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" } else { - "" + s""" + if ($funcResult != null) { + ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + } else { + ${ev.isNull} = true; + } + """ } s""" Object $funcResult = null; ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; $postNullCheck """ } From 510fb530ebf3d9235206cefe8e428bf3f8689cfc Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 9 Apr 2017 00:04:13 +0900 Subject: [PATCH 8/9] address review comment --- .../spark/sql/catalyst/expressions/objects/objects.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d86e9f13ce7b4..023d347aa5f6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -227,7 +227,7 @@ case class Invoke( val funcResult = ctx.freshName("funcResult") // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. 
- val postNullCheck = if (!returnNullable) { + val postNullCheckAndAssign = if (!returnNullable) { s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" } else { s""" @@ -241,7 +241,7 @@ case class Invoke( s""" Object $funcResult = null; ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - $postNullCheck + $postNullCheckAndAssign """ } From 10cf4be41d1de37115edc140e1421caf5b23336a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 9 Apr 2017 13:40:17 +0900 Subject: [PATCH 9/9] address review comment --- .../spark/sql/catalyst/expressions/objects/objects.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 023d347aa5f6a..ca9a8cf564fa9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -227,7 +227,7 @@ case class Invoke( val funcResult = ctx.freshName("funcResult") // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. - val postNullCheckAndAssign = if (!returnNullable) { + val assignResult = if (!returnNullable) { s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" } else { s""" @@ -241,7 +241,7 @@ case class Invoke( s""" Object $funcResult = null; ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - $postNullCheckAndAssign + $assignResult """ }