1717
1818package org .apache .spark .sql .catalyst .expressions
1919
20- import java .sql .Timestamp
20+ import java .sql .{ Date , Timestamp }
2121import java .text .{DateFormat , SimpleDateFormat }
2222
23+ import org .apache .spark .Logging
2324import org .apache .spark .sql .catalyst .types ._
2425
2526/** Cast the child expression to the target data type. */
26- case class Cast (child : Expression , dataType : DataType ) extends UnaryExpression {
27+ case class Cast (child : Expression , dataType : DataType ) extends UnaryExpression with Logging {
/** A cast can be constant-folded exactly when the expression it wraps can be. */
override def foldable = {
  child.foldable
}
2829
/**
 * Whether the result of this cast may be null.
 *
 * A cast whose input is a string and whose target is a numeric, timestamp or date type
 * is always reported as nullable; every other cast inherits the nullability of the child.
 */
override def nullable = child.dataType match {
  case StringType =>
    dataType match {
      case _: NumericType | TimestampType | DateType => true
      case _ => child.nullable
    }
  case _ =>
    child.nullable
}
3436
@@ -42,6 +44,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
// UDFToString
/**
 * Selects, based on the child's data type, the function that converts a value to a string.
 *
 *  - binary values are decoded as UTF-8 text;
 *  - dates and timestamps are rendered by `dateToString` / `timestampToString`;
 *  - every other type falls back to `toString`.
 *
 * Each conversion is routed through `buildCast` (defined elsewhere in this class).
 */
private[this] def castToString: Any => Any = child.dataType match {
  // The charset name must not carry surrounding whitespace, otherwise the JVM rejects it
  // with an UnsupportedEncodingException.
  case BinaryType => buildCast[Array[Byte]](_, bytes => new String(bytes, "UTF-8"))
  case DateType => buildCast[Date](_, dateToString)
  case TimestampType => buildCast[Timestamp](_, timestampToString)
  case _ => buildCast[Any](_, _.toString)
}
@@ -56,7 +59,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
5659 case StringType =>
5760 buildCast[String ](_, _.length() != 0 )
5861 case TimestampType =>
59- buildCast[Timestamp ](_, b => b.getTime() != 0 || b.getNanos() != 0 )
62+ buildCast[Timestamp ](_, t => t.getTime() != 0 || t.getNanos() != 0 )
63+ case DateType =>
64+ buildCast[Date ](_, d => null )
6065 case LongType =>
6166 buildCast[Long ](_, _ != 0 )
6267 case IntegerType =>
@@ -95,6 +100,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
95100 buildCast[Short ](_, s => new Timestamp (s))
96101 case ByteType =>
97102 buildCast[Byte ](_, b => new Timestamp (b))
103+ case DateType =>
104+ buildCast[Date ](_, d => Timestamp .valueOf(dateToString(d) + " 00:00:00" ))
98105 // TimestampWritable.decimalToTimestamp
99106 case DecimalType =>
100107 buildCast[BigDecimal ](_, d => decimalToTimestamp(d))
@@ -130,7 +137,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
130137 // Converts Timestamp to string according to Hive TimestampWritable convention
131138 private [this ] def timestampToString (ts : Timestamp ): String = {
132139 val timestampString = ts.toString
133- val formatted = Cast .threadLocalDateFormat .get.format(ts)
140+ val formatted = Cast .threadLocalTimestampFormat .get.format(ts)
134141
135142 if (timestampString.length > 19 && timestampString.substring(19 ) != " .0" ) {
136143 formatted + timestampString.substring(19 )
@@ -139,13 +146,48 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
139146 }
140147 }
141148
/**
 * Renders a timestamp as a date-only string, dropping the time-of-day part.
 * Formatting is delegated to the per-thread formatter `Cast.threadLocalDateFormat`.
 */
private[this] def timestampToDateString(ts: Timestamp): String = {
  val dayFormat = Cast.threadLocalDateFormat.get
  dayFormat.format(ts)
}
153+
// DateConverter
/**
 * Selects, based on the child's data type, the function that converts a value to a
 * `java.sql.Date`.
 *
 *  - strings without a space are parsed with `Date.valueOf`;
 *  - strings containing a space are treated as date-and-time text: they are first parsed
 *    by `castToTimestamp` and the resulting timestamp is truncated to its date;
 *  - timestamps are truncated to their date via `timestampToDateString`;
 *  - every other source type yields null.
 *
 * Text that cannot be parsed (an `IllegalArgumentException` from the parsers) yields null.
 */
private[this] def castToDate: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, s =>
      try {
        if (s.contains(" ")) {
          // Convert the parsed timestamp directly. Calling castToDate on it would select
          // this string branch again (the child's type has not changed) and hand a
          // non-String value to the string handler.
          // NOTE(review): castToTimestamp's string handling is not visible here; any
          // result that is not a Timestamp (e.g. null) is mapped to null.
          castToTimestamp(s) match {
            case ts: Timestamp => Date.valueOf(timestampToDateString(ts))
            case _ => null
          }
        } else {
          Date.valueOf(s)
        }
      } catch {
        case _: java.lang.IllegalArgumentException => null
      })
  case TimestampType =>
    buildCast[Timestamp](_, t => Date.valueOf(timestampToDateString(t)))
  // No conversion to date is defined for the remaining source types.
  case _ =>
    _ => null
}
169+
/**
 * Always returns null: according to Hive, a date cannot be cast to long.
 * The argument is ignored.
 */
private[this] def dateToLong(d: Date): Null = null
172+
/**
 * Always returns null: according to Hive, a date cannot be cast to double.
 * The argument is ignored.
 */
private[this] def dateToDouble(d: Date): Null = null
175+
/**
 * Renders a date as a string using the per-thread formatter
 * `Cast.threadLocalDateFormat`.
 */
private[this] def dateToString(d: Date): String = {
  val dayFormat = Cast.threadLocalDateFormat.get
  dayFormat.format(d)
}
180+
181+ // LongConverter
142182 private [this ] def castToLong : Any => Any = child.dataType match {
143183 case StringType =>
144184 buildCast[String ](_, s => try s.toLong catch {
145185 case _ : NumberFormatException => null
146186 })
147187 case BooleanType =>
148188 buildCast[Boolean ](_, b => if (b) 1L else 0L )
189+ case DateType =>
190+ buildCast[Date ](_, d => dateToLong(d))
149191 case TimestampType =>
150192 buildCast[Timestamp ](_, t => timestampToLong(t))
151193 case DecimalType =>
@@ -154,13 +196,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
154196 b => x.numeric.asInstanceOf [Numeric [Any ]].toLong(b)
155197 }
156198
199+ // IntConverter
157200 private [this ] def castToInt : Any => Any = child.dataType match {
158201 case StringType =>
159202 buildCast[String ](_, s => try s.toInt catch {
160203 case _ : NumberFormatException => null
161204 })
162205 case BooleanType =>
163206 buildCast[Boolean ](_, b => if (b) 1 else 0 )
207+ case DateType =>
208+ buildCast[Date ](_, d => dateToLong(d))
164209 case TimestampType =>
165210 buildCast[Timestamp ](_, t => timestampToLong(t).toInt)
166211 case DecimalType =>
@@ -169,13 +214,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
169214 b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b)
170215 }
171216
217+ // ShortConverter
172218 private [this ] def castToShort : Any => Any = child.dataType match {
173219 case StringType =>
174220 buildCast[String ](_, s => try s.toShort catch {
175221 case _ : NumberFormatException => null
176222 })
177223 case BooleanType =>
178224 buildCast[Boolean ](_, b => if (b) 1 .toShort else 0 .toShort)
225+ case DateType =>
226+ buildCast[Date ](_, d => dateToLong(d))
179227 case TimestampType =>
180228 buildCast[Timestamp ](_, t => timestampToLong(t).toShort)
181229 case DecimalType =>
@@ -184,13 +232,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
184232 b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toShort
185233 }
186234
235+ // ByteConverter
187236 private [this ] def castToByte : Any => Any = child.dataType match {
188237 case StringType =>
189238 buildCast[String ](_, s => try s.toByte catch {
190239 case _ : NumberFormatException => null
191240 })
192241 case BooleanType =>
193242 buildCast[Boolean ](_, b => if (b) 1 .toByte else 0 .toByte)
243+ case DateType =>
244+ buildCast[Date ](_, d => dateToLong(d))
194245 case TimestampType =>
195246 buildCast[Timestamp ](_, t => timestampToLong(t).toByte)
196247 case DecimalType =>
@@ -199,27 +250,33 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
199250 b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toByte
200251 }
201252
// DecimalConverter
/**
 * Selects, based on the child's data type, the function that converts a value to a
 * `BigDecimal`.
 *
 *  - strings are parsed as a double first; text that is not a valid number yields null;
 *  - booleans become 1 (true) or 0 (false);
 *  - dates go through `dateToDouble`, which always returns null;
 *  - timestamps go through `timestampToDouble` (precision is lost in that step);
 *  - other numeric types are widened to double via their `Numeric` instance.
 */
private[this] def castToDecimal: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, text =>
      try {
        BigDecimal(text.toDouble)
      } catch {
        case _: NumberFormatException => null
      })
  case BooleanType =>
    buildCast[Boolean](_, flag => BigDecimal(if (flag) 1 else 0))
  case DateType =>
    buildCast[Date](_, day => dateToDouble(day))
  case TimestampType =>
    // Going through a double drops precision.
    buildCast[Timestamp](_, ts => {
      val asDouble = timestampToDouble(ts)
      BigDecimal(asDouble)
    })
  case x: NumericType =>
    value => {
      val asDouble = x.numeric.asInstanceOf[Numeric[Any]].toDouble(value)
      BigDecimal(asDouble)
    }
}
215269
270+ // DoubleConverter
216271 private [this ] def castToDouble : Any => Any = child.dataType match {
217272 case StringType =>
218273 buildCast[String ](_, s => try s.toDouble catch {
219274 case _ : NumberFormatException => null
220275 })
221276 case BooleanType =>
222277 buildCast[Boolean ](_, b => if (b) 1d else 0d )
278+ case DateType =>
279+ buildCast[Date ](_, d => dateToDouble(d))
223280 case TimestampType =>
224281 buildCast[Timestamp ](_, t => timestampToDouble(t))
225282 case DecimalType =>
@@ -228,13 +285,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
228285 b => x.numeric.asInstanceOf [Numeric [Any ]].toDouble(b)
229286 }
230287
288+ // FloatConverter
231289 private [this ] def castToFloat : Any => Any = child.dataType match {
232290 case StringType =>
233291 buildCast[String ](_, s => try s.toFloat catch {
234292 case _ : NumberFormatException => null
235293 })
236294 case BooleanType =>
237295 buildCast[Boolean ](_, b => if (b) 1f else 0f )
296+ case DateType =>
297+ buildCast[Date ](_, d => dateToDouble(d))
238298 case TimestampType =>
239299 buildCast[Timestamp ](_, t => timestampToDouble(t).toFloat)
240300 case DecimalType =>
@@ -245,17 +305,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
245305
/**
 * The conversion function for this cast, chosen once from the target `dataType`.
 *
 * When the target type equals the child's type the value is passed through unchanged;
 * otherwise the matching `castToXxx` builder supplies the function.
 */
private[this] lazy val cast: Any => Any =
  if (dataType == child.dataType) {
    identity[Any]
  } else {
    dataType match {
      // Text and binary targets.
      case StringType => castToString
      case BinaryType => castToBinary
      // Temporal targets.
      case DateType => castToDate
      case TimestampType => castToTimestamp
      // Boolean target.
      case BooleanType => castToBoolean
      // Integral targets.
      case ByteType => castToByte
      case ShortType => castToShort
      case IntegerType => castToInt
      case LongType => castToLong
      // Fractional targets.
      case FloatType => castToFloat
      case DoubleType => castToDouble
      case DecimalType => castToDecimal
    }
  }
260321
261322 override def eval (input : Row ): Any = {
@@ -267,6 +328,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
267328object Cast {
/**
 * Per-thread formatter for date-only strings in `yyyy-MM-dd` form.
 *
 * `SimpleDateFormat` is not thread-safe, so every thread lazily gets its own instance.
 * The pattern carries no leading or trailing whitespace, so the formatted text has the
 * plain `yyyy-MM-dd` shape that `java.sql.Date.valueOf` accepts.
 */
private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] {
  override def initialValue() = new SimpleDateFormat("yyyy-MM-dd")
}
335+
336+ // `SimpleDateFormat` is not thread-safe.
337+ private [sql] val threadLocalTimestampFormat = new ThreadLocal [DateFormat ] {
270338 override def initialValue () = {
271339 new SimpleDateFormat (" yyyy-MM-dd HH:mm:ss" )
272340 }
0 commit comments