diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala
index e7f0e571804d..8d4acd1608f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PostgreSQLDialect.scala
@@ -17,29 +17,28 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.expressions.postgreSQL.PostgreCastToBoolean
+import org.apache.spark.sql.catalyst.expressions.postgreSQL._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, StringType}
+import org.apache.spark.sql.types._
 
 object PostgreSQLDialect {
-  val postgreSQLDialectRules: List[Rule[LogicalPlan]] =
-    CastToBoolean ::
-    Nil
+  val postgreSQLDialectRules: Seq[Rule[LogicalPlan]] = Seq(
+    PostgresCast
+  )
 
-  object CastToBoolean extends Rule[LogicalPlan] with Logging {
+  object PostgresCast extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = {
-      // The SQL configuration `spark.sql.dialect` can be changed in runtime.
-      // To make sure the configuration is effective, we have to check it during rule execution.
-      val conf = SQLConf.get
-      if (conf.usePostgreSQLDialect) {
+      if (SQLConf.get.usePostgreSQLDialect) {
         plan.transformExpressions {
           case Cast(child, dataType, timeZoneId)
-              if child.dataType != BooleanType && dataType == BooleanType =>
+              if dataType == BooleanType && child.dataType != BooleanType =>
             PostgreCastToBoolean(child, timeZoneId)
+          case Cast(child, dataType, timeZoneId)
+              if dataType == TimestampType && child.dataType != TimestampType =>
+            PostgreCastToTimestamp(child, timeZoneId)
         }
       } else {
         plan
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f3b58fa3137b..89d2807f9f1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -410,7 +410,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   }
 
   // TimestampConverter
-  private[this] def castToTimestamp(from: DataType): Any => Any = from match {
+  protected[this] def castToTimestamp(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs, zoneId).orNull)
     case BooleanType =>
@@ -1159,7 +1159,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
     }
   }
 
-  private[this] def castToTimestampCode(
+  protected[this] def castToTimestampCode(
       from: DataType,
       ctx: CodegenContext): CastFunction = from match {
     case StringType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastBase.scala
similarity index 51%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToBoolean.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastBase.scala
index 20559ba3cd79..94d395d75180 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastToBoolean.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/PostgreCastBase.scala
@@ -16,29 +16,51 @@
  */
 package org.apache.spark.sql.catalyst.expressions.postgreSQL
 
+import java.time.ZoneId
+
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.{CastBase, Expression, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, JavaCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.catalyst.util.postgreSQL.StringUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-case class PostgreCastToBoolean(child: Expression, timeZoneId: Option[String])
-  extends CastBase {
+abstract class PostgreCastBase(toType: DataType) extends CastBase {
 
-  override protected def ansiEnabled =
-    throw new UnsupportedOperationException("PostgreSQL dialect doesn't support ansi mode")
+  def fromTypes: TypeCollection
 
-  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
-    copy(timeZoneId = Option(timeZoneId))
+  override def dataType: DataType = toType
+
+  override protected def ansiEnabled: Boolean =
+    throw new UnsupportedOperationException("PostgreSQL dialect doesn't support ansi mode")
 
-  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
-    case StringType | IntegerType | NullType =>
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!fromTypes.acceptsType(child.dataType)) {
+      TypeCheckResult.TypeCheckFailure(
+        s"cannot cast type ${child.dataType.simpleString} to ${toType.simpleString}")
+    } else {
       TypeCheckResult.TypeCheckSuccess
-    case _ =>
-      TypeCheckResult.TypeCheckFailure(s"cannot cast type ${child.dataType} to boolean")
+    }
   }
 
+  override def nullable: Boolean = child.nullable
+
+  override def sql: String = s"CAST(${child.sql} AS ${toType.sql})"
+
+  override def toString: String =
+    s"PostgreCastTo${toType.simpleString.capitalize}($child as ${toType.simpleString})"
+}
+
+case class PostgreCastToBoolean(child: Expression, timeZoneId: Option[String])
+  extends PostgreCastBase(BooleanType) {
+
+  override def fromTypes: TypeCollection = TypeCollection(StringType, IntegerType, NullType)
+
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
+
   override def castToBoolean(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[UTF8String](_, str => {
@@ -58,7 +80,7 @@ case class PostgreCastToBoolean(child: Expression, timeZoneId: Option[String])
   override def castToBooleanCode(from: DataType): CastFunction = from match {
     case StringType =>
       val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
-      (c, evPrim, evNull) =>
+      (c, evPrim, _) =>
         code"""
           if ($stringUtils.isTrueString($c.trim().toLowerCase())) {
             $evPrim = true;
@@ -68,16 +90,48 @@ case class PostgreCastToBoolean(child: Expression, timeZoneId: Option[String])
             throw new IllegalArgumentException("invalid input syntax for type boolean: $c");
           }
         """
-
     case IntegerType =>
       super.castToBooleanCode(from)
   }
+}
 
-  override def dataType: DataType = BooleanType
+case class PostgreCastToTimestamp(child: Expression, timeZoneId: Option[String])
+  extends PostgreCastBase(TimestampType) {
 
-  override def nullable: Boolean = child.nullable
+  override def fromTypes: TypeCollection = TypeCollection(StringType, DateType, NullType)
 
-  override def toString: String = s"PostgreCastToBoolean($child as ${dataType.simpleString})"
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
 
-  override def sql: String = s"CAST(${child.sql} AS ${dataType.sql})"
+  override def castToTimestamp(from: DataType): Any => Any = from match {
+    case StringType =>
+      buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs, zoneId)
+        .getOrElse(throw new IllegalArgumentException(
+          s"invalid input syntax for type timestamp: $utfs")))
+    case DateType =>
+      super.castToTimestamp(from)
+  }
+
+  override def castToTimestampCode(
+      from: DataType,
+      ctx: CodegenContext): CastFunction = from match {
+    case StringType =>
+      val zoneIdClass = classOf[ZoneId]
+      val zid = JavaCode.global(
+        ctx.addReferenceObj("zoneId", zoneId, zoneIdClass.getName),
+        zoneIdClass)
+      val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
+      (c, evPrim, _) =>
+        code"""
+          scala.Option<Long> $longOpt =
+            org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $zid);
+          if ($longOpt.isDefined()) {
+            $evPrim = ((Long) $longOpt.get()).longValue();
+          } else {
+            throw new IllegalArgumentException("invalid input syntax for type timestamp: " + $c);
+          }
+        """
+    case DateType =>
+      super.castToTimestampCode(from, ctx)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala
index 6c5218b379f3..2de9ad61448a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/postgreSQL/CastSuite.scala
@@ -70,4 +70,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(PostgreCastToBoolean(Literal(1.toDouble), None).checkInputDataTypes().isFailure)
     assert(PostgreCastToBoolean(Literal(1.toFloat), None).checkInputDataTypes().isFailure)
   }
+
+  test("unsupported data types to cast to timestamp") {
+    assert(PostgreCastToTimestamp(Literal(1.toInt), None).checkInputDataTypes().isFailure)
+    assert(PostgreCastToTimestamp(Literal(1.toByte), None).checkInputDataTypes().isFailure)
+    assert(PostgreCastToTimestamp(Literal(1.toDouble), None).checkInputDataTypes().isFailure)
+    assert(PostgreCastToTimestamp(Literal(1.toFloat), None).checkInputDataTypes().isFailure)
+    assert(PostgreCastToTimestamp(Literal(1.toLong), None).checkInputDataTypes().isFailure)
+    assert(PostgreCastToTimestamp(Literal(1.toShort), None).checkInputDataTypes().isFailure)
+    assert(PostgreCastToTimestamp(Literal(BigDecimal(1.0)), None).checkInputDataTypes().isFailure)
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/cast.sql
new file mode 100644
index 000000000000..918034f1e3ce
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/cast.sql
@@ -0,0 +1,10 @@
+SELECT CAST(1 AS timestamp);
+SELECT CAST(1.1 AS timestamp);
+SELECT CAST(CAST(1 AS float) AS timestamp);
+SELECT CAST(CAST(1 AS boolean) AS timestamp);
+SELECT CAST(CAST(1 AS byte) AS timestamp);
+SELECT CAST(CAST(1 AS short) AS timestamp);
+SELECT CAST(CAST(1 AS double) AS timestamp);
+SELECT CAST(CAST('2019' AS date) AS timestamp);
+SELECT CAST(NULL AS timestamp);
+SELECT CAST('2019' AS timestamp);
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/cast.sql.out
new file mode 100644
index 000000000000..a897b610ddcc
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/cast.sql.out
@@ -0,0 +1,88 @@
+-- Number of queries: 10
+
+
+-- !query 0
+SELECT CAST(1 AS timestamp)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1 AS TIMESTAMP)' due to data type mismatch: cannot cast type int to timestamp
+
+
+-- !query 1
+SELECT CAST(1.1 AS timestamp)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(1.1BD AS TIMESTAMP)' due to data type mismatch: cannot cast type decimal(2,1) to timestamp
+
+
+-- !query 2
+SELECT CAST(CAST(1 AS float) AS timestamp)
+-- !query 2 schema
+struct<>
+-- !query 2 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(CAST(1 AS FLOAT) AS TIMESTAMP)' due to data type mismatch: cannot cast type float to timestamp
+
+
+-- !query 3
+SELECT CAST(CAST(1 AS boolean) AS timestamp)
+-- !query 3 schema
+struct<>
+-- !query 3 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(CAST(1 AS BOOLEAN) AS TIMESTAMP)' due to data type mismatch: cannot cast type boolean to timestamp
+
+
+-- !query 4
+SELECT CAST(CAST(1 AS byte) AS timestamp)
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(CAST(1 AS TINYINT) AS TIMESTAMP)' due to data type mismatch: cannot cast type tinyint to timestamp
+
+
+-- !query 5
+SELECT CAST(CAST(1 AS short) AS timestamp)
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(CAST(1 AS SMALLINT) AS TIMESTAMP)' due to data type mismatch: cannot cast type smallint to timestamp
+
+
+-- !query 6
+SELECT CAST(CAST(1 AS double) AS timestamp)
+-- !query 6 schema
+struct<>
+-- !query 6 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'CAST(CAST(1 AS DOUBLE) AS TIMESTAMP)' due to data type mismatch: cannot cast type double to timestamp
+
+
+-- !query 7
+SELECT CAST(CAST('2019' AS date) AS timestamp)
+-- !query 7 schema
+struct<CAST(CAST(2019 AS DATE) AS TIMESTAMP):timestamp>
+-- !query 7 output
+2019-01-01 00:00:00.0
+
+
+-- !query 8
+SELECT CAST(NULL AS timestamp)
+-- !query 8 schema
+struct<CAST(NULL AS TIMESTAMP):timestamp>
+-- !query 8 output
+
+
+
+-- !query 9
+SELECT CAST('2019' AS timestamp)
+-- !query 9 schema
+struct<CAST(2019 AS TIMESTAMP):timestamp>
+-- !query 9 output
+2019-01-01 00:00:00.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala
index 7056f483609a..2f709f776fa1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PostgreSQLDialectQuerySuite.scala
@@ -39,4 +39,11 @@ class PostgreSQLDialectQuerySuite extends QueryTest with SharedSparkSession {
       intercept[IllegalArgumentException](sql(s"select cast('$input' as boolean)").collect())
     }
   }
+
+  test("cast to timestamp") {
+    Seq(1, 0.1, 1.toDouble, 5.toFloat, true, 3.toByte, 4.toShort).foreach { value =>
+      intercept[IllegalArgumentException](sql(s"select cast('$value' as timestamp)").collect())
+    }
+  }
 }
+
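A minimal sketch of the behavior this patch is aiming for, not part of the diff itself. It assumes a Spark 3.0-preview-era build where the `spark.sql.dialect` configuration exists; the object name `PostgreCastToTimestampDemo` is made up for illustration:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object PostgreCastToTimestampDemo extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .config("spark.sql.dialect", "PostgreSQL") // enables the PostgreSQLDialect rules
    .getOrCreate()

  // A parsable string still casts as before: 2019-01-01 00:00:00.
  spark.sql("SELECT CAST('2019' AS timestamp)").show()

  // With the dialect enabled, an unparsable string throws IllegalArgumentException
  // at runtime instead of silently producing NULL (the default Spark behavior).
  try {
    spark.sql("SELECT CAST('not a timestamp' AS timestamp)").collect()
  } catch {
    case e: IllegalArgumentException => println(e.getMessage)
  }

  // Numeric sources are rejected earlier, at analysis time, by checkInputDataTypes:
  // "... cannot cast type int to timestamp".
  try {
    spark.sql("SELECT CAST(1 AS timestamp)").collect()
  } catch {
    case e: AnalysisException => println(e.getMessage)
  }

  spark.stop()
}
```

This mirrors PostgreSQL's split between analysis-time rejection (numeric and boolean sources are never castable to timestamp) and runtime errors (string sources are accepted by the analyzer but must actually parse), the same split the new `fromTypes` check and the throwing `castToTimestamp` paths implement.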