diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index a597a17aadd99..1985e68c94e2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,14 +17,25 @@ package org.apache.spark.sql.catalyst.expressions +import java.lang.{Boolean => JavaBoolean} +import java.lang.{Byte => JavaByte} +import java.lang.{Double => JavaDouble} +import java.lang.{Float => JavaFloat} +import java.lang.{Integer => JavaInteger} +import java.lang.{Long => JavaLong} +import java.lang.{Short => JavaShort} +import java.math.{BigDecimal => JavaBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.util import java.util.Objects import javax.xml.bind.DatatypeConverter +import scala.math.{BigDecimal, BigInt} + import org.json4s.JsonAST._ +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -46,12 +57,17 @@ object Literal { case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) - case d: java.math.BigDecimal => + case d: JavaBigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) + case a: Array[_] => + val elementType = componentTypeToDataType(a.getClass.getComponentType()) + val dataType = ArrayType(elementType) + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + Literal(convert(a), dataType) case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) case v: Literal => v @@ -59,6 +75,45 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Returns the Spark SQL DataType for a given class object. Since this type needs to be resolved + * in runtime, we use match-case idioms for class objects here. However, there are similar + * functions in other files (e.g., HiveInspectors), so these functions need to merged into one. + */ + private[this] def componentTypeToDataType(clz: Class[_]): DataType = clz match { + // primitive types + case JavaShort.TYPE => ShortType + case JavaInteger.TYPE => IntegerType + case JavaLong.TYPE => LongType + case JavaDouble.TYPE => DoubleType + case JavaByte.TYPE => ByteType + case JavaFloat.TYPE => FloatType + case JavaBoolean.TYPE => BooleanType + + // java classes + case _ if clz == classOf[Date] => DateType + case _ if clz == classOf[Timestamp] => TimestampType + case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[Array[Byte]] => BinaryType + case _ if clz == classOf[JavaShort] => ShortType + case _ if clz == classOf[JavaInteger] => IntegerType + case _ if clz == classOf[JavaLong] => LongType + case _ if clz == classOf[JavaDouble] => DoubleType + case _ if clz == classOf[JavaByte] => ByteType + case _ if clz == classOf[JavaFloat] => FloatType + case _ if clz == classOf[JavaBoolean] => BooleanType + + // other scala classes + case _ if clz == classOf[String] => StringType + case _ if clz == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT + case _ if clz == classOf[CalendarInterval] => CalendarIntervalType + + case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType)) + + case _ => throw new AnalysisException(s"Unsupported component type $clz in arrays") + } + /** * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object * into code generation. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 450222d8cbba3..4af4da8a9f0c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -43,6 +44,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, TimestampType), null) checkEvaluation(Literal.create(null, CalendarIntervalType), null) checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) + checkEvaluation(Literal.create(null, ArrayType(StringType, true)), null) checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) } @@ -122,5 +124,28 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } } - // TODO(davies): add tests for ArrayType, MapType and StructType + test("array") { + def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = { + val toCatalyst = (a: Array[_], elementType: DataType) => { + CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a) + } + checkEvaluation(Literal(a), toCatalyst(a, elementType)) + } + checkArrayLiteral(Array(1, 2, 3), IntegerType) + checkArrayLiteral(Array("a", "b", "c"), StringType) + checkArrayLiteral(Array(1.0, 4.0), DoubleType) + checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR), + CalendarIntervalType) + } + + test("unsupported types (map and struct) in literals") { + def checkUnsupportedTypeInLiteral(v: Any): Unit = { + val errMsgMap = intercept[RuntimeException] { + Literal(v) + } + assert(errMsgMap.getMessage.startsWith("Unsupported literal type")) + } + checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) + checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) + } }