diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index 91c9457af7de..4e129e96d1c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DecimalType /** @@ -72,7 +73,7 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case (e, i) => val writer = InternalRow.getWriter(i, e.dataType) - if (!e.nullable) { + if (!e.nullable || e.dataType.isInstanceOf[DecimalType]) { (v: Any) => writer(mutableRow, v) } else { (v: Any) => { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 0f01bfbb8941..e3f11283816c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -65,6 +65,68 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow) } + def testRows( + bufferSchema: StructType, + buffer: InternalRow, + scalaRows: Seq[Seq[Any]]): Unit = { + val bufferTypes = bufferSchema.map(_.dataType).toArray + val proj = createMutableProjection(bufferTypes) + + scalaRows.foreach { scalaRow => + val inputRow = InternalRow.fromSeq(scalaRow.zip(bufferTypes).map { + case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v) + }) + val projRow = proj.target(buffer)(inputRow) + assert(SafeProjection.create(bufferTypes)(projRow) === inputRow) + } + } + + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (high precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(27, 2), nullable = true), + StructField("dec2", DecimalType(27, 2), nullable = true))) + val buffer = UnsafeProjection.create(bufferSchema) + .apply(new GenericInternalRow(bufferSchema.length)) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (low precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(10, 2), nullable = true), + StructField("dec2", DecimalType(10, 2), nullable = true))) + val buffer = UnsafeProjection.create(bufferSchema) + .apply(new GenericInternalRow(bufferSchema.length)) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + + testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal (high precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(27, 2), nullable = true), + StructField("dec2", DecimalType(27, 2), nullable = true))) + val buffer = new GenericInternalRow(bufferSchema.length) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + + testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal (low precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(10, 2), nullable = true), + StructField("dec2", DecimalType(10, 2), nullable = true))) + val buffer = new GenericInternalRow(bufferSchema.length) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + testBothCodegenAndInterpreted("variable-length types") { val proj = createMutableProjection(variableLengthTypes) val scalaValues = Seq("abc", BigDecimal(10),