Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ object DeserializerBuildHelper {
case _: StructType => expr
case _: ArrayType => expr
case _: MapType => expr
case _: DecimalType =>
// For Scala/Java `BigDecimal`, we accept decimal types of any valid precision/scale.
// Here we use the `DecimalType` object to indicate it.
UpCast(expr, DecimalType, walkedTypePath.getPaths)
Comment thread
HyukjinKwon marked this conversation as resolved.
case _ => UpCast(expr, expected, walkedTypePath.getPaths)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3071,15 +3071,27 @@ class Analyzer(
case p => p transformExpressions {
case u @ UpCast(child, _, _) if !child.resolved => u

case UpCast(child, dt: AtomicType, _)
case u @ UpCast(child, _, _)
if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) &&
u.dataType.isInstanceOf[AtomicType] &&
child.dataType == StringType =>
Cast(child, dt.asNullable)

case UpCast(child, dataType, walkedTypePath) if !Cast.canUpCast(child.dataType, dataType) =>
fail(child, dataType, walkedTypePath)

case UpCast(child, dataType, _) => Cast(child, dataType.asNullable)
Cast(child, u.dataType.asNullable)

case UpCast(child, target, walkedTypePath)
Comment thread
Ngone51 marked this conversation as resolved.
Outdated
if child.dataType.isInstanceOf[DecimalType]
&& target == DecimalType
&& walkedTypePath.nonEmpty =>
Comment thread
HyukjinKwon marked this conversation as resolved.
Outdated
// SPARK-31750: for the case where data type is explicitly known, e.g, spark.read

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

SPARK-31750: if we want to upcast to the general decimal type, and the `child` is already
decimal type, we can remove the `Upcast` and accept any precision/scale.
This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`.

// .parquet("/tmp/file").as[BigDecimal], we will have UpCast(child, Decimal(38, 18)),
// where child's data type can be, e.g. Decimal(38, 0). In this kind of case, we
// actually should not do cast otherwise it will cause precision lost. Thus, we should
// eliminate the UpCast here to avoid precision lost.
child

case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) =>
fail(child, u.dataType, walkedTypePath)

case u @ UpCast(child, _, _) => Cast(child, u.dataType.asNullable)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1735,8 +1735,15 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St
/**
* Cast the child expression to the target data type, but will throw error if the cast might
* truncate, e.g. long -> int, timestamp -> data.
*
* Note UpCast will be eliminated if the child's dataType is already DecimalType.
Comment thread
HyukjinKwon marked this conversation as resolved.
Outdated
*/
case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String] = Nil)
case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: Seq[String] = Nil)
extends UnaryExpression with Unevaluable {
override lazy val resolved = false

def dataType: DataType = target match {
case DecimalType => DecimalType.SYSTEM_DEFAULT
case _ => target.asInstanceOf[DataType]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -247,6 +247,13 @@ class EncoderResolutionSuite extends PlanTest {
""".stripMargin.trim + " of the field in the target object")
}

test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") {
val encoder = ExpressionEncoder[Seq[BigDecimal]]
val attr = Seq(AttributeReference("a", ArrayType(DecimalType(38, 0)))())
// previously, it will fail because Decimal(38, 0) can not be casted to Decimal(38, 18)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

previously -> Before SPARK-31750

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

previously -> Before SPARK-31750

testFromRow(encoder, attr, InternalRow(ArrayData.toArrayData(Array(Decimal(1.0)))))
}

// test for leaf types
castSuccess[Int, Long]
castSuccess[java.sql.Date, java.sql.Timestamp]
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2439,6 +2439,17 @@ class DataFrameSuite extends QueryTest
val nestedDecArray = Array(decSpark)
checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
}

test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") {
withTempPath { f =>
sql("select cast(11111111111111111111111111111111111111 as decimal(38, 0)) as d")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test can still reproduce the bug even if we use 1 instead of 1111...?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It depends on the precision/scale rather than the value itself.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can make it shorter.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test can still reproduce the bug even if we use 1 instead of 1111...?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I've changed it to 1 to simplify the test.

.write.mode("overwrite")
.parquet(f.getAbsolutePath)

val df = spark.read.parquet(f.getAbsolutePath).as[BigDecimal]
assert(df.schema === new StructType().add(StructField("d", DecimalType(38, 0))))
}
}
}

case class GroupByKey(a: Int, b: Int)