From 55e1b720611b09bb85aebd9037883dd863afabf6 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 2 Dec 2022 17:06:24 -0800 Subject: [PATCH 1/4] possible fix and test --- .../InterpretedMutableProjection.scala | 27 ++++++++++++++----- .../expressions/MutableProjectionSuite.scala | 22 +++++++++++++++ 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index 91c9457af7de..56c5fe19977d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{Decimal, DecimalType} /** @@ -72,16 +73,28 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case (e, i) => val writer = InternalRow.getWriter(i, e.dataType) + val nullSafeWriter: (InternalRow, Any) => Unit = e.dataType match { + case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS => + (row, v: Any) => { + if (v == null && !row.isInstanceOf[UnsafeRow]) { + row.setNullAt(i) + } else { + writer(row, v) + } + } + case _ => + (row, v: Any) => { + if (v == null) { + row.setNullAt(i) + } else { + writer(row, v) + } + } + } if (!e.nullable) { (v: Any) => writer(mutableRow, v) } else { - (v: Any) => { - if (v == null) { - mutableRow.setNullAt(i) - } else { - writer(mutableRow, v) - } - } + (v: Any) => nullSafeWriter(mutableRow, v) } }.toArray diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 0f01bfbb8941..6c1b3ff750fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -65,6 +65,28 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow) } + testBothCodegenAndInterpreted("unsafe buffer with null decimal") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(27, 2), nullable = true), + StructField("dec2", DecimalType(27, 2), nullable = true))) + val bufferTypes = Array[DataType](DecimalType(27, 2), DecimalType(27, 2)) + val proj = createMutableProjection(bufferTypes) + val unsafeBuffer = UnsafeProjection.create(bufferSchema) + .apply(new GenericInternalRow(bufferSchema.length)) + + val scalaRows = Seq( + Seq(BigDecimal(5), null), + Seq(BigDecimal(10), BigDecimal(11))) + + scalaRows.foreach { scalaRow => + val inputRow = InternalRow.fromSeq(scalaRow.zip(bufferTypes).map { + case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v) + }) + val projRow = proj.target(unsafeBuffer)(inputRow) + assert(SafeProjection.create(bufferTypes)(projRow) === inputRow) + } + } + testBothCodegenAndInterpreted("variable-length types") { val proj = createMutableProjection(variableLengthTypes) val scalaValues = Seq("abc", BigDecimal(10), From 34b764a9dd64ef459c4ce38b18155889263924d4 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 4 Dec 2022 18:40:00 -0800 Subject: [PATCH 2/4] update --- .../InterpretedMutableProjection.scala | 41 +++++++++++-------- .../expressions/MutableProjectionSuite.scala | 4 +- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index 56c5fe19977d..e84a986d1596 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -58,6 +58,7 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable case _ => true } private[this] var mutableRow: InternalRow = new GenericInternalRow(expressions.size) + private[this] var unsafeMutableRow = false def currentValue: InternalRow = mutableRow override def target(row: InternalRow): MutableProjection = { @@ -68,32 +69,36 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable validExprs.map(_._1.dataType).filterNot(UnsafeRow.isMutable) .map(_.catalogString).mkString(", ")) mutableRow = row + unsafeMutableRow = if (mutableRow.isInstanceOf[UnsafeRow]) true else false; this } private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case (e, i) => val writer = InternalRow.getWriter(i, e.dataType) - val nullSafeWriter: (InternalRow, Any) => Unit = e.dataType match { - case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS => - (row, v: Any) => { - if (v == null && !row.isInstanceOf[UnsafeRow]) { - row.setNullAt(i) - } else { - writer(row, v) - } - } - case _ => - (row, v: Any) => { - if (v == null) { - row.setNullAt(i) - } else { - writer(row, v) - } - } - } if (!e.nullable) { (v: Any) => writer(mutableRow, v) } else { + val nullSafeWriter: (InternalRow, Any) => Unit = e.dataType match { + case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS => + (row, v: Any) => { + // The check of unsafeMutableRow has to happen at run time, rather than at + // field writer creation time, because `InterpretedMutableProjection#target` + // can be called after the field writers are created. + if (v == null && !unsafeMutableRow) { + row.setNullAt(i) + } else { + writer(row, v) + } + } + case _ => + (row, v: Any) => { + if (v == null) { + row.setNullAt(i) + } else { + writer(row, v) + } + } + } (v: Any) => nullSafeWriter(mutableRow, v) } }.toArray diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 6c1b3ff750fb..0c933cc8528a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -75,8 +75,8 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { .apply(new GenericInternalRow(bufferSchema.length)) val scalaRows = Seq( - Seq(BigDecimal(5), null), - Seq(BigDecimal(10), BigDecimal(11))) + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) scalaRows.foreach { scalaRow => val inputRow = InternalRow.fromSeq(scalaRow.zip(bufferTypes).map { From ea338f0ead192b34deaabadc50cffa425f26e01b Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 5 Dec 2022 12:01:48 -0800 Subject: [PATCH 3/4] update --- .../expressions/MutableProjectionSuite.scala | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 0c933cc8528a..bcaa335bafeb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -65,7 +65,7 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow) } - testBothCodegenAndInterpreted("unsafe buffer with null decimal") { + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (high precision)") { val bufferSchema = StructType(Array( StructField("dec1", DecimalType(27, 2), nullable = true), StructField("dec2", DecimalType(27, 2), nullable = true))) @@ -87,6 +87,28 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { } } + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (low precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(10, 2), nullable = true), + StructField("dec2", DecimalType(10, 2), nullable = true))) + val bufferTypes = Array[DataType](DecimalType(10, 2), DecimalType(10, 2)) + val proj = createMutableProjection(bufferTypes) + val unsafeBuffer = UnsafeProjection.create(bufferSchema) + .apply(new GenericInternalRow(bufferSchema.length)) + + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + + scalaRows.foreach { scalaRow => + val inputRow = InternalRow.fromSeq(scalaRow.zip(bufferTypes).map { + case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v) + }) + val projRow = proj.target(unsafeBuffer)(inputRow) + assert(SafeProjection.create(bufferTypes)(projRow) === inputRow) + } + } + testBothCodegenAndInterpreted("variable-length types") { val proj = createMutableProjection(variableLengthTypes) val scalaValues = Seq("abc", BigDecimal(10), From 484a7db04e016be5fdfb8754f1329e3b85993115 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 6 Dec 2022 13:51:36 -0800 Subject: [PATCH 4/4] Simplify based on what codegen does --- .../InterpretedMutableProjection.scala | 33 +++------- .../expressions/MutableProjectionSuite.scala | 62 ++++++++++++------- 2 files changed, 48 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index e84a986d1596..4e129e96d1c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.DecimalType /** @@ -58,7 +58,6 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable case _ => true } private[this] var mutableRow: InternalRow = new GenericInternalRow(expressions.size) - private[this] var unsafeMutableRow = false def currentValue: InternalRow = mutableRow override def target(row: InternalRow): MutableProjection = { @@ -69,37 +68,21 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable validExprs.map(_._1.dataType).filterNot(UnsafeRow.isMutable) .map(_.catalogString).mkString(", ")) mutableRow = row - unsafeMutableRow = if (mutableRow.isInstanceOf[UnsafeRow]) true else false; this } private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case (e, i) => val writer = InternalRow.getWriter(i, e.dataType) - if (!e.nullable) { + if (!e.nullable || e.dataType.isInstanceOf[DecimalType]) { (v: Any) => writer(mutableRow, v) } else { - val nullSafeWriter: (InternalRow, Any) => Unit = e.dataType match { - case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS => - (row, v: Any) => { - // The check of unsafeMutableRow has to happen at run time, rather than at - // field writer creation time, because `InterpretedMutableProjection#target` - // can be called after the field writers are created. - if (v == null && !unsafeMutableRow) { - row.setNullAt(i) - } else { - writer(row, v) - } - } - case _ => - (row, v: Any) => { - if (v == null) { - row.setNullAt(i) - } else { - writer(row, v) - } - } + (v: Any) => { + if (v == null) { + mutableRow.setNullAt(i) + } else { + writer(mutableRow, v) + } } - (v: Any) => nullSafeWriter(mutableRow, v) } }.toArray diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index bcaa335bafeb..e3f11283816c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -65,48 +65,66 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow) } - testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (high precision)") { - val bufferSchema = StructType(Array( - StructField("dec1", DecimalType(27, 2), nullable = true), - StructField("dec2", DecimalType(27, 2), nullable = true))) - val bufferTypes = Array[DataType](DecimalType(27, 2), DecimalType(27, 2)) + def testRows( + bufferSchema: StructType, + buffer: InternalRow, + scalaRows: Seq[Seq[Any]]): Unit = { + val bufferTypes = bufferSchema.map(_.dataType).toArray val proj = createMutableProjection(bufferTypes) - val unsafeBuffer = UnsafeProjection.create(bufferSchema) - .apply(new GenericInternalRow(bufferSchema.length)) - - val scalaRows = Seq( - Seq(null, null), - Seq(BigDecimal(77.77), BigDecimal(245.00))) scalaRows.foreach { scalaRow => val inputRow = InternalRow.fromSeq(scalaRow.zip(bufferTypes).map { case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v) }) - val projRow = proj.target(unsafeBuffer)(inputRow) + val projRow = proj.target(buffer)(inputRow) assert(SafeProjection.create(bufferTypes)(projRow) === inputRow) } } + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (high precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(27, 2), nullable = true), + StructField("dec2", DecimalType(27, 2), nullable = true))) + val buffer = UnsafeProjection.create(bufferSchema) + .apply(new GenericInternalRow(bufferSchema.length)) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (low precision)") { val bufferSchema = StructType(Array( StructField("dec1", DecimalType(10, 2), nullable = true), StructField("dec2", DecimalType(10, 2), nullable = true))) - val bufferTypes = Array[DataType](DecimalType(10, 2), DecimalType(10, 2)) - val proj = createMutableProjection(bufferTypes) - val unsafeBuffer = UnsafeProjection.create(bufferSchema) + val buffer = UnsafeProjection.create(bufferSchema) .apply(new GenericInternalRow(bufferSchema.length)) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal (high precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(27, 2), nullable = true), + StructField("dec2", DecimalType(27, 2), nullable = true))) + val buffer = new GenericInternalRow(bufferSchema.length) val scalaRows = Seq( Seq(null, null), Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } - scalaRows.foreach { scalaRow => - val inputRow = InternalRow.fromSeq(scalaRow.zip(bufferTypes).map { - case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v) - }) - val projRow = proj.target(unsafeBuffer)(inputRow) - assert(SafeProjection.create(bufferTypes)(projRow) === inputRow) - } + testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal (low precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(10, 2), nullable = true), + StructField("dec2", DecimalType(10, 2), nullable = true))) + val buffer = new GenericInternalRow(bufferSchema.length) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) } testBothCodegenAndInterpreted("variable-length types") {