Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ private[kafka010] object KafkaWriter extends Logging {
s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
} else {
Literal(topic.get, StringType)
Literal.create(topic.get, StringType)
}
).dataType match {
case StringType => // good
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ import org.json4s.JsonAST._
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types._
import org.apache.spark.util.Utils

object Literal {
val TrueLiteral: Literal = Literal(true, BooleanType)
Expand Down Expand Up @@ -196,6 +197,47 @@ object Literal {
case other =>
throw new RuntimeException(s"no default for type $dataType")
}

/**
 * Checks that `value` is an instance of the runtime class this method accepts for `dataType`
 * and fails with `require` when it is not.
 *
 * A null `value` is accepted for every data type. Struct, array, map and user-defined types
 * are validated recursively through the nested `doValidate` helper.
 *
 * @param value the literal value to check; may be null
 * @param dataType the data type the literal claims to have
 * @throws IllegalArgumentException (raised by `require`) if `value` does not match `dataType`
 */
private[expressions] def validateLiteralValue(value: Any, dataType: DataType): Unit = {
// Returns true when `v` is acceptable for `dataType`. Case order matters: the null check
// comes first so that the type-specific cases below never see a null value.
def doValidate(v: Any, dataType: DataType): Boolean = dataType match {
case _ if v == null => true
case BooleanType => v.isInstanceOf[Boolean]
case ByteType => v.isInstanceOf[Byte]
case ShortType => v.isInstanceOf[Short]
// DateType shares the Int check with IntegerType; TimestampType shares the Long check
// with LongType.
case IntegerType | DateType => v.isInstanceOf[Int]
case LongType | TimestampType => v.isInstanceOf[Long]
case FloatType => v.isInstanceOf[Float]
case DoubleType => v.isInstanceOf[Double]
case _: DecimalType => v.isInstanceOf[Decimal]
case CalendarIntervalType => v.isInstanceOf[CalendarInterval]
case BinaryType => v.isInstanceOf[Array[Byte]]
case StringType => v.isInstanceOf[UTF8String]
case st: StructType =>
v.isInstanceOf[InternalRow] && {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do the same for array and map?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

val row = v.asInstanceOf[InternalRow]
// Every field of the row is validated against the data type of the matching struct field.
st.fields.map(_.dataType).zipWithIndex.forall {
case (dt, i) => doValidate(row.get(i, dt), dt)
}
}
case at: ArrayType =>
v.isInstanceOf[ArrayData] && {
val ar = v.asInstanceOf[ArrayData]
// Only the first element is inspected (an empty array always passes); the remaining
// elements are not checked.
ar.numElements() == 0 || doValidate(ar.get(0, at.elementType), at.elementType)
}
case mt: MapType =>
v.isInstanceOf[MapData] && {
val map = v.asInstanceOf[MapData]
// Keys and values are each validated as an array, so the first-element-only rule of the
// ArrayType case applies to both.
doValidate(map.keyArray(), ArrayType(mt.keyType)) &&
doValidate(map.valueArray(), ArrayType(mt.valueType))
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether we don't need to check the whole elements for ArrayType and MapType.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the whole element check seems to be expensive, the current one is ok to me.

// ObjectType accepts any instance of its class; a user-defined type is validated against
// its underlying sqlType. Any data type not listed above rejects every non-null value.
case ObjectType(cls) => cls.isInstance(v)
case udt: UserDefinedType[_] => doValidate(v, udt.sqlType)
case _ => false
Copy link
Member

@ueshin ueshin Oct 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add NullType case?
nvm, not needed usually.

}
// The message argument of require is by-name, so value.getClass is only evaluated when
// validation fails. A null value always passes doValidate, so getClass is never called on null.
require(doValidate(value, dataType),
s"Literal must have a corresponding value to ${dataType.catalogString}, " +
s"but class ${Utils.getSimpleName(value.getClass)} found.")
}
}

/**
Expand Down Expand Up @@ -240,6 +282,8 @@ object DecimalLiteral {
*/
case class Literal (value: Any, dataType: DataType) extends LeafExpression {

Literal.validateLiteralValue(value, dataType)

override def foldable: Boolean = true
override def nullable: Boolean = value == null

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ class TypeCoercionSuite extends AnalysisTest {
val nullLit = Literal.create(null, NullType)
val floatNullLit = Literal.create(null, FloatType)
val floatLit = Literal.create(1.0f, FloatType)
val timestampLit = Literal.create("2017-04-12", TimestampType)
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis())))
val strArrayLit = Literal(Array("c"))
Expand Down Expand Up @@ -793,11 +793,11 @@ class TypeCoercionSuite extends AnalysisTest {
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
:: Literal.create(1.0f, FloatType)
:: Nil),
CreateArray(Literal(1.0)
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Cast(Literal.create(1.0f, FloatType), DoubleType)
:: Nil))

ruleTest(TypeCoercion.FunctionArgumentConversion,
Expand Down Expand Up @@ -834,23 +834,23 @@ class TypeCoercionSuite extends AnalysisTest {
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal.create(2.0, FloatType)
:: Literal.create(2.0f, FloatType)
:: Literal("b")
:: Nil),
CreateMap(Cast(Literal(1), FloatType)
:: Literal("a")
:: Literal.create(2.0, FloatType)
:: Literal.create(2.0f, FloatType)
:: Literal("b")
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal.create(null, DecimalType(5, 3))
:: Literal("a")
:: Literal.create(2.0, FloatType)
:: Literal.create(2.0f, FloatType)
:: Literal("b")
:: Nil),
CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType)
:: Literal("a")
:: Literal.create(2.0, FloatType).cast(DoubleType)
:: Literal.create(2.0f, FloatType).cast(DoubleType)
:: Literal("b")
:: Nil))
// type coercion for map values
Expand Down Expand Up @@ -895,11 +895,11 @@ class TypeCoercionSuite extends AnalysisTest {
ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
:: Literal.create(1.0f, FloatType)
:: Nil),
operator(Literal(1.0)
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Cast(Literal.create(1.0f, FloatType), DoubleType)
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1L)
Expand Down Expand Up @@ -966,7 +966,7 @@ class TypeCoercionSuite extends AnalysisTest {
val falseLit = Literal.create(false, BooleanType)
val stringLit = Literal.create("c", StringType)
val floatLit = Literal.create(1.0f, FloatType)
val timestampLit = Literal.create("2017-04-12", TimestampType)
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))

ruleTest(rule,
Expand Down Expand Up @@ -1016,14 +1016,16 @@ class TypeCoercionSuite extends AnalysisTest {
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
ruleTest(TypeCoercion.CaseWhenCoercion,
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(1.2))),
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(1.2))),
Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DoubleType))
)
ruleTest(TypeCoercion.CaseWhenCoercion,
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(100L))),
Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DecimalType(22, 2)))
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with

test("SPARK-21513: to_json support map[string, struct] to json") {
val schema = MapType(StringType, StructType(StructField("a", IntegerType) :: Nil))
val input = Literal.create(ArrayBasedMapData(Map("test" -> InternalRow(1))), schema)
val input = Literal(
ArrayBasedMapData(Map(UTF8String.fromString("test") -> InternalRow(1))), schema)
checkEvaluation(
StructsToJson(Map.empty, input),
"""{"test":{"a":1}}"""
Expand All @@ -633,7 +634,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
test("SPARK-21513: to_json support map[struct, struct] to json") {
val schema = MapType(StructType(StructField("a", IntegerType) :: Nil),
StructType(StructField("b", IntegerType) :: Nil))
val input = Literal.create(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
val input = Literal(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
checkEvaluation(
StructsToJson(Map.empty, input),
"""{"[1]":{"b":2}}"""
Expand All @@ -642,7 +643,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with

test("SPARK-21513: to_json support map[string, integer] to json") {
val schema = MapType(StringType, IntegerType)
val input = Literal.create(ArrayBasedMapData(Map("a" -> 1)), schema)
val input = Literal(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)), schema)
checkEvaluation(
StructsToJson(Map.empty, input),
"""{"a":1}"""
Expand All @@ -651,17 +652,18 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with

test("to_json - array with maps") {
val inputSchema = ArrayType(MapType(StringType, IntegerType))
val input = new GenericArrayData(ArrayBasedMapData(
Map("a" -> 1)) :: ArrayBasedMapData(Map("b" -> 2)) :: Nil)
val input = new GenericArrayData(
ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) ::
ArrayBasedMapData(Map(UTF8String.fromString("b") -> 2)) :: Nil)
val output = """[{"a":1},{"b":2}]"""
checkEvaluation(
StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
StructsToJson(Map.empty, Literal(input, inputSchema), gmtId),
output)
}

test("to_json - array with single map") {
val inputSchema = ArrayType(MapType(StringType, IntegerType))
val input = new GenericArrayData(ArrayBasedMapData(Map("a" -> 1)) :: Nil)
val input = new GenericArrayData(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) :: Nil)
val output = """[{"a":1}]"""
checkEvaluation(
StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
Expand Down Expand Up @@ -107,8 +109,8 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val nullLit = Literal.create(null, NullType)
val floatNullLit = Literal.create(null, FloatType)
val floatLit = Literal.create(1.01f, FloatType)
val timestampLit = Literal.create("2017-04-12", TimestampType)
val decimalLit = Literal.create(10.2, DecimalType(20, 2))
val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
val decimalLit = Literal.create(BigDecimal.valueOf(10.2), DecimalType(20, 2))

assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType)
assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}
import java.sql.Timestamp
import java.util.TimeZone

import org.apache.spark.SparkFunSuite
Expand All @@ -32,9 +32,9 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val b2 = Literal.create(true, BooleanType)
val i1 = Literal.create(20132983, IntegerType)
val i2 = Literal.create(-20132983, IntegerType)
val l1 = Literal.create(20132983, LongType)
val l2 = Literal.create(-20132983, LongType)
val millis = 1524954911000L;
val l1 = Literal.create(20132983L, LongType)
val l2 = Literal.create(-20132983L, LongType)
val millis = 1524954911000L
// Explicitly choose a time zone, since Date objects can create different values depending on
// local time zone of the machine on which the test is running
val oldDefaultTZ = TimeZone.getDefault
Expand All @@ -57,7 +57,7 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val dec1 = Literal(Decimal(20132983L, 10, 2))
val dec2 = Literal(Decimal(20132983L, 19, 2))
val dec3 = Literal(Decimal(20132983L, 21, 2))
val list1 = Literal(List(1, 2), ArrayType(IntegerType))
val list1 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
val nullVal = Literal.create(null, IntegerType)

checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
}

test("parse sql expression for duration in microseconds - long") {
val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2 << 52, LongType)))
val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2L << 52, LongType)))
assert(dur.isInstanceOf[Long])
assert(dur === (2 << 52))
assert(dur === (2L << 52))
}

test("parse sql expression for duration in microseconds - invalid interval") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,14 @@ class PercentileSuite extends SparkFunSuite {
BooleanType, StringType, DateType, TimestampType, CalendarIntervalType, NullType)

invalidDataTypes.foreach { dataType =>
val percentage = Literal(0.5, dataType)
val percentage = Literal.default(dataType)
val percentile4 = new Percentile(child, percentage)
assertEqual(percentile4.checkInputDataTypes(),
TypeCheckFailure(s"argument 2 requires double type, however, " +
s"'0.5' is of ${dataType.simpleString} type."))
val checkResult = percentile4.checkInputDataTypes()
assert(checkResult.isFailure)
Seq("argument 2 requires double type, however, ",
s"is of ${dataType.simpleString} type.").foreach { errMsg =>
assert(checkResult.asInstanceOf[TypeCheckFailure].message.contains(errMsg))
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ case class AnalyzeColumnCommand(
def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
})
val one = Literal(1, LongType)
val one = Literal(1L, LongType)

// the approximate ndv (num distinct value) should never be larger than the number of rows
val numNonNulls = if (col.nullable) Count(col) else Count(one)
val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls))
val numNulls = Subtract(Count(one), numNonNulls)
val defaultSize = Literal(col.dataType.defaultSize, LongType)
val defaultSize = Literal(col.dataType.defaultSize.toLong, LongType)
val nullArray = Literal(null, ArrayType(LongType))

def fixedLenTypeStruct: CreateNamedStruct = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
test("join key rewritten") {
val l = Literal(1L)
val i = Literal(2)
val s = Literal.create(3, ShortType)
val s = Literal.create(3.toShort, ShortType)
val ss = Literal("hello")

assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)
Expand Down