Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));"""
case TimestampType =>
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
(c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
case _ =>
Expand Down Expand Up @@ -633,7 +633,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
}
"""
case TimestampType =>
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
(c, evPrim, evNull) =>
s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);"
case _ =>
Expand Down Expand Up @@ -713,7 +713,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val longOpt = ctx.freshName("longOpt")
(c, evPrim, evNull) =>
s"""
Expand All @@ -730,7 +730,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case _: IntegralType =>
(c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};"
case DateType =>
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
(c, evPrim, evNull) =>
s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;"
case DecimalType() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ case class ScalaUDF(
ctx: CodegenContext,
ev: ExprCode): ExprCode = {
val scalaUDF = ctx.freshName("scalaUDF")
val scalaUDFRef = ctx.addReferenceMinorObj(this, scalaUDFClassName)
val scalaUDFRef = ctx.addReferenceObj("scalaUDFRef", this, scalaUDFClassName)

// Object to convert the returned value of user-defined functions to Catalyst type
val catalystConverterTerm = ctx.freshName("catalystConverter")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,28 +109,14 @@ class CodegenContext {
*
* Returns the code to access it.
*
* This is for minor objects not to store the object into field but refer it from the references
* field at the time of use because number of fields in class is limited so we should reduce it.
* This does not store the object in a class field; instead the object is read from the
* `references` field at the time of use, because the number of fields in a class is limited
* and should be kept small.
*/
def addReferenceMinorObj(obj: Any, className: String = null): String = {
def addReferenceObj(objName: String, obj: Any, className: String = null): String = {
val idx = references.length
references += obj
val clsName = Option(className).getOrElse(obj.getClass.getName)
s"(($clsName) references[$idx])"
}

/**
* Add an object to `references`, create a class member to access it.
*
* Returns the name of class member.
*/
def addReferenceObj(name: String, obj: Any, className: String = null): String = {
val term = freshName(name)
val idx = references.length
references += obj
val clsName = Option(className).getOrElse(obj.getClass.getName)
addMutableState(clsName, term, s"$term = ($clsName) references[$idx];")
term
s"(($clsName) references[$idx] /* $objName */)"
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$dtu.getHours($c, $tz)")
}
Expand Down Expand Up @@ -257,7 +257,7 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c, $tz)")
}
Expand Down Expand Up @@ -288,7 +288,7 @@ case class Second(child: Expression, timeZoneId: Option[String] = None)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c, $tz)")
}
Expand Down Expand Up @@ -529,7 +529,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
defineCodeGen(ctx, ev, (timestamp, format) => {
s"""UTF8String.fromString($dtu.newDateFormat($format.toString(), $tz)
.format(new java.util.Date($timestamp / 1000)))"""
Expand Down Expand Up @@ -691,7 +691,7 @@ abstract class UnixTime
}""")
}
case StringType =>
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, (string, format) => {
s"""
Expand All @@ -715,7 +715,7 @@ abstract class UnixTime
${ev.value} = ${eval1.value} / 1000000L;
}""")
case DateType =>
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val eval1 = left.genCode(ctx)
ev.copy(code = s"""
Expand Down Expand Up @@ -827,7 +827,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
}""")
}
} else {
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, (seconds, f) => {
s"""
Expand Down Expand Up @@ -969,7 +969,7 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)"""
Expand Down Expand Up @@ -1065,7 +1065,7 @@ case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[S
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $tz)"""
Expand Down Expand Up @@ -1143,7 +1143,7 @@ case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Optio
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tz = ctx.addReferenceMinorObj(timeZone)
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (l, r) => {
s"""$dtu.monthsBetween($l, $r, $tz)"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
case FloatType =>
val v = value.asInstanceOf[Float]
if (v.isNaN || v.isInfinite) {
val boxedValue = ctx.addReferenceMinorObj(v)
val boxedValue = ctx.addReferenceObj("boxedValue", v)
val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;"
ev.copy(code = code)
} else {
Expand All @@ -299,7 +299,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
case DoubleType =>
val v = value.asInstanceOf[Double]
if (v.isNaN || v.isInfinite) {
val boxedValue = ctx.addReferenceMinorObj(v)
val boxedValue = ctx.addReferenceObj("boxedValue", v)
val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;"
ev.copy(code = code)
} else {
Expand All @@ -309,8 +309,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
ev.copy(code = "", value = s"($javaType)$value")
case TimestampType | LongType =>
ev.copy(code = "", value = s"${value}L")
case other =>
ev.copy(code = "", value = ctx.addReferenceMinorObj(value, ctx.javaType(dataType)))
case _ =>
ev.copy(code = "", value = ctx.addReferenceObj("literal", value,
ctx.javaType(dataType)))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa

// Use an inline reference (no class field is created) to reduce the number of fields,
// because errMsgField is used only when the value is null or false.
val errMsgField = ctx.addReferenceMinorObj(errMsg)
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
ExprCode(code = s"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
| throw new RuntimeException($errMsgField);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
expressions = childrenCodes,
funcName = "createExternalRow",
extraArguments = "Object[]" -> values :: Nil)
val schemaField = ctx.addReferenceMinorObj(schema)
val schemaField = ctx.addReferenceObj("schema", schema)

val code =
s"""
Expand Down Expand Up @@ -1310,7 +1310,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)

// Use an inline reference (no class field is created) to reduce the number of fields,
// because errMsgField is used only when the value is null.
val errMsgField = ctx.addReferenceMinorObj(errMsg)
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)

val code = s"""
${childGen.code}
Expand Down Expand Up @@ -1347,7 +1347,7 @@ case class GetExternalRowField(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Use an inline reference (no class field is created) to reduce the number of fields,
// because errMsgField is used only when the field is null.
val errMsgField = ctx.addReferenceMinorObj(errMsg)
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val row = child.genCode(ctx)
val code = s"""
${row.code}
Expand Down Expand Up @@ -1387,7 +1387,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Use an inline reference (no class field is created) to reduce the number of fields,
// because errMsgField is used only when the type doesn't match.
val errMsgField = ctx.addReferenceMinorObj(errMsg)
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val input = child.genCode(ctx)
val obj = input.value

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,4 +394,11 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
Map("add" -> Literal(1))).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
}

// Regression test for SPARK-22716: registering an object through addReferenceObj on a
// fresh CodegenContext must leave the context's mutableStates empty.
test("SPARK-22716: addReferenceObj should not add mutable states") {
val ctx = new CodegenContext
// Any object will do; only the effect on the context is checked, not the returned code.
val foo = new Object()
ctx.addReferenceObj("foo", foo)
// No mutable state may have been registered as a side effect of the call above.
assert(ctx.mutableStates.isEmpty)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,4 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Length(Uuid()), 36)
assert(evaluate(Uuid()) !== evaluate(Uuid()))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22695: ScalaUDF should not use global variables") {
val ctx = new CodegenContext
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
// ScalaUDF can be very verbose and trigger reduceCodeSize
assert(ctx.mutableStates.forall(_._2.startsWith("globalIsNull")))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class RowBasedHashMapGenerator(
val generatedKeySchema: String =
s"new org.apache.spark.sql.types.StructType()" +
groupingKeySchema.map { key =>
val keyName = ctx.addReferenceMinorObj(key.name)
val keyName = ctx.addReferenceObj("keyName", key.name)
key.dataType match {
case d: DecimalType =>
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
Expand All @@ -60,7 +60,7 @@ class RowBasedHashMapGenerator(
val generatedValueSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
val keyName = ctx.addReferenceMinorObj(key.name)
val keyName = ctx.addReferenceObj("keyName", key.name)
key.dataType match {
case d: DecimalType =>
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class VectorizedHashMapGenerator(
val generatedSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
(groupingKeySchema ++ bufferSchema).map { key =>
val keyName = ctx.addReferenceMinorObj(key.name)
val keyName = ctx.addReferenceObj("keyName", key.name)
key.dataType match {
case d: DecimalType =>
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
Expand All @@ -67,7 +67,7 @@ class VectorizedHashMapGenerator(
val generatedAggBufferSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
val keyName = ctx.addReferenceMinorObj(key.name)
val keyName = ctx.addReferenceObj("keyName", key.name)
key.dataType match {
case d: DecimalType =>
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType(
Expand Down