diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 11d6d049f7..859cb13bea 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -70,9 +70,13 @@ object CometCast {
           case _ => Unsupported
         }
 
-      case (_: DecimalType, _: DecimalType) =>
-        // https://github.com/apache/datafusion-comet/issues/375
-        Incompatible()
+      case (from: DecimalType, to: DecimalType) =>
+        if (to.precision < from.precision) {
+          // https://github.com/apache/datafusion/issues/13492
+          Incompatible(Some("Casting to smaller precision is not supported"))
+        } else {
+          Compatible()
+        }
       case (DataTypes.StringType, _) =>
         canCastFromString(toType, timeZoneId, evalMode)
       case (_, DataTypes.StringType) =>
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index db9a870dc0..d60e3d6d9e 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -895,6 +895,35 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("cast between decimals with different precision and scale") {
+    // cast from the default Decimal(38, 18) to Decimal(6, 2)
+    val values = Seq(BigDecimal("12345.6789"), BigDecimal("9876.5432"), BigDecimal("123.4567"))
+    val df = withNulls(values)
+      .toDF("b")
+      .withColumn("a", col("b").cast(DecimalType(6, 2)))
+    checkSparkAnswer(df)
+  }
+
+  test("cast between decimals with higher scale than source") {
+    // cast from Decimal(10, 2) to Decimal(10, 4)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
+  }
+
+  test("cast between decimals with negative scale") {
+    // cast to negative scale
+    checkSparkMaybeThrows(
+      spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a")) match {
+      case (expected, actual) =>
+        assert(expected.exists(_.getMessage.contains("PARSE_SYNTAX_ERROR")) ===
+          actual.exists(_.getMessage.contains("PARSE_SYNTAX_ERROR")))
+    }
+  }
+
+  test("cast between decimals with zero scale") {
+    // cast from Decimal(10, 2) to Decimal(10, 0)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 0))
+  }
+
   private def generateFloats(): DataFrame = {
     withNulls(gen.generateFloats(dataSize)).toDF("a")
   }
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 1709cce61c..213ec7efee 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -231,11 +231,9 @@ abstract class CometTestBase
       df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
     var expected: Option[Throwable] = None
     withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-      val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
-      expected = Try(dfSpark.collect()).failed.toOption
+      expected = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
     }
-    val dfComet = Dataset.ofRows(spark, df.logicalPlan)
-    val actual = Try(dfComet.collect()).failed.toOption
+    val actual = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
     (expected, actual)
   }
 
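
For readers skimming the patch, here is a minimal standalone sketch of the rule the new `CometCast` match arm encodes. `DecimalCastRule` and `decimalCastSupported` are hypothetical names used only for illustration, not part of Comet; the sketch assumes just Spark's `DecimalType` from `spark-catalyst` on the classpath:

```scala
import org.apache.spark.sql.types.DecimalType

object DecimalCastRule {
  // Hypothetical mirror of the new match arm: a decimal-to-decimal cast is
  // marked Compatible (Comet runs it natively) only when the target precision
  // is at least the source precision. Narrowing casts remain Incompatible and
  // by default fall back to Spark until
  // https://github.com/apache/datafusion/issues/13492 is resolved.
  def decimalCastSupported(from: DecimalType, to: DecimalType): Boolean =
    to.precision >= from.precision

  def main(args: Array[String]): Unit = {
    // Decimal(10, 2) -> Decimal(10, 4): equal precision, handled natively
    assert(decimalCastSupported(DecimalType(10, 2), DecimalType(10, 4)))
    // Decimal(38, 18) -> Decimal(6, 2): smaller target precision, falls back to Spark
    assert(!decimalCastSupported(DecimalType(38, 18), DecimalType(6, 2)))
  }
}
```

Note that a non-narrowing cast can still overflow for particular values (for example `99999999.99` as `Decimal(10, 2)` does not fit in `Decimal(10, 4)`); the `castTest` calls in the suite above are presumably what verify Comet matches Spark's null-or-error behavior in those cases.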