Skip to content

Commit 2adb235

Browse files
committed
Change timestamp cast semantics. When cast to numeric types, return the unix time in seconds (instead of millis).
1 parent 0307db0 commit 2adb235

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ package object dsl {
104104
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
105105
// TODO more implicit class for literal?
106106
implicit class DslString(val s: String) extends ImplicitOperators {
107-
def expr: Expression = Literal(s)
107+
override def expr: Expression = Literal(s)
108108
def attr = analysis.UnresolvedAttribute(s)
109109
}
110110

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
8787

8888
private def decimalToTimestamp(d: BigDecimal) = {
8989
val seconds = d.longValue()
90-
val bd = (d - seconds) * (1000000000)
90+
val bd = (d - seconds) * 1000000000
9191
val nanos = bd.intValue()
9292

9393
// Convert to millis
@@ -96,18 +96,20 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
9696

9797
// remaining fractional portion as nanos
9898
t.setNanos(nanos)
99-
10099
t
101100
}
102101

103-
private def timestampToDouble(t: Timestamp) = (t.getSeconds() + t.getNanos().toDouble / 1000)
102+
// Timestamp to long, converting milliseconds to seconds
103+
private def timestampToLong(ts: Timestamp) = ts.getTime / 1000
104+
105+
private def timestampToDouble(ts: Timestamp) = ts.getTime.toDouble / 1000
104106

105107
def castToLong: Any => Any = child.dataType match {
106108
case StringType => nullOrCast[String](_, s => try s.toLong catch {
107109
case _: NumberFormatException => null
108110
})
109111
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
110-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong)
112+
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t))
111113
case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
112114
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
113115
}
@@ -117,7 +119,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
117119
case _: NumberFormatException => null
118120
})
119121
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
120-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toInt)
122+
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt)
121123
case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
122124
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
123125
}
@@ -127,7 +129,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
127129
case _: NumberFormatException => null
128130
})
129131
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
130-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort)
132+
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort)
131133
case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
132134
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
133135
}
@@ -137,7 +139,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
137139
case _: NumberFormatException => null
138140
})
139141
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
140-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte)
142+
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte)
141143
case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
142144
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
143145
}
@@ -147,7 +149,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
147149
case _: NumberFormatException => null
148150
})
149151
case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
150-
case TimestampType => nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
152+
case TimestampType =>
153+
// Note that we lose precision here.
154+
nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
151155
case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
152156
}
153157

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,14 @@ class ExpressionEvaluationSuite extends FunSuite {
198198

199199
val sts = "1970-01-01 00:00:01.0"
200200
val ts = Timestamp.valueOf(sts)
201-
201+
202202
checkEvaluation("abdef" cast StringType, "abdef")
203203
checkEvaluation("abdef" cast DecimalType, null)
204204
checkEvaluation("abdef" cast TimestampType, null)
205205
checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65))
206206

207207
checkEvaluation(Literal(1) cast LongType, 1)
208208
checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1)
209-
checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
210209
checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
211210

212211
checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts)
@@ -237,12 +236,31 @@ class ExpressionEvaluationSuite extends FunSuite {
237236

238237
intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
239238
}
240-
239+
241240
test("timestamp") {
242241
val ts1 = new Timestamp(12)
243242
val ts2 = new Timestamp(123)
244243
checkEvaluation(Literal("ab") < Literal("abc"), true)
245244
checkEvaluation(Literal(ts1) < Literal(ts2), true)
246245
}
246+
247+
test("timestamp casting") {
248+
val millis = 15 * 1000 + 1
249+
val ts = new Timestamp(millis)
250+
val ts1 = new Timestamp(15 * 1000) // a timestamp without the milliseconds part
251+
checkEvaluation(ts cast ShortType, 15)
252+
checkEvaluation(ts cast IntegerType, 15)
253+
checkEvaluation(ts cast LongType, 15)
254+
checkEvaluation(ts cast FloatType, 15.001f)
255+
checkEvaluation(ts cast DoubleType, 15.001)
256+
checkEvaluation(Cast(Cast(ts, ShortType), TimestampType), ts1)
257+
checkEvaluation(Cast(Cast(ts, IntegerType), TimestampType), ts1)
258+
checkEvaluation(Cast(Cast(ts, LongType), TimestampType), ts1)
259+
checkEvaluation(Cast(Cast(millis.toFloat / 1000, TimestampType), FloatType),
260+
millis.toFloat / 1000)
261+
checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType),
262+
millis.toDouble / 1000)
263+
checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
264+
}
247265
}
248266

0 commit comments

Comments
 (0)