@@ -677,8 +677,10 @@ object TypeCoercion {
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) =>
val preferIntegralDivision =
conf.getConf(SQLConf.DIALECT) == SQLConf.Dialect.POSTGRESQL.toString
(left.dataType, right.dataType) match {
case (_: IntegralType, _: IntegralType) if conf.preferIntegralDivision =>
case (_: IntegralType, _: IntegralType) if preferIntegralDivision =>
IntegralDivide(left, right)
case _ =>
Divide(Cast(left, DoubleType), Cast(right, DoubleType))
@@ -391,10 +391,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
val dialect = SQLConf.get.getConf(SQLConf.DIALECT)
buildCast[UTF8String](_, s => {
if (StringUtils.isTrueString(s)) {
if (StringUtils.isTrueString(s, dialect)) {
Review comment from gengliangwang (Member, Author): In the case of more dialects, we can pass the dialect to StringUtils as a parameter here and avoid further changes in Cast.scala.

Review comment from gengliangwang (Member, Author), Sep 23, 2019: Note that this brings performance overhead. I will add a new expression instead.

true
} else if (StringUtils.isFalseString(s)) {
} else if (StringUtils.isFalseString(s, dialect)) {
false
} else {
null
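The comments above flag the cost of threading the dialect through as a plain string: as written, isTrueString re-resolves the dialect enum for every value it checks before doing the actual set lookup. A minimal sketch of that per-row work, using only the calls introduced in this diff (the loop itself is purely illustrative):

import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.unsafe.types.UTF8String

val dialect = SQLConf.get.getConf(SQLConf.DIALECT)
val value = UTF8String.fromString("true")
// Every call runs SQLConf.Dialect.withName(dialect) plus a pattern match before
// the set lookup; this repeated dispatch is the overhead the author plans to
// remove by adding a dedicated expression. The loop is illustrative, not PR code.
(1 to 1000000).foreach(_ => StringUtils.isTrueString(value, dialect))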
@@ -1250,11 +1251,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
case StringType =>
val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
val dialect = SQLConf.get.getConf(SQLConf.DIALECT)
(c, evPrim, evNull) =>
code"""
if ($stringUtils.isTrueString($c)) {
if ($stringUtils.isTrueString($c, "$dialect")) {
$evPrim = true;
} else if ($stringUtils.isFalseString($c)) {
} else if ($stringUtils.isFalseString($c, "$dialect")) {
$evPrim = false;
} else {
$evNull = true;
@@ -65,16 +65,34 @@ object StringUtils extends Logging {
"(?s)" + out.result() // (?s) enables dotall mode, causing "." to match new lines
}

// "true", "yes", "1", "false", "no", "0", and unique prefixes of these strings are accepted.
private[this] val trueStrings =
Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map(UTF8String.fromString)
Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
// "true", "yes", "1", "false", "no", "0", and unique prefixes of these strings are accepted.
private[this] val trueStringsOfPostgreSQL =
Set("true", "tru", "tr", "t", "yes", "ye", "y", "on", "1").map (UTF8String.fromString)

private[this] val falseStrings =
Set("f", "false", "n", "no", "0").map(UTF8String.fromString)
private[this] val falseStringsOfPostgreSQL =
Set("false", "fals", "fal", "fa", "f", "no", "n", "off", "of", "0").map(UTF8String.fromString)

// scalastyle:off caselocale
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase.trim())
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase.trim())
def isTrueString(s: UTF8String, dialect: String): Boolean = {
SQLConf.Dialect.withName(dialect) match {
case SQLConf.Dialect.SPARK =>
trueStrings.contains(s.toLowerCase)
case SQLConf.Dialect.POSTGRESQL =>
trueStringsOfPostgreSQL.contains(s.toLowerCase.trim())
}
}

def isFalseString(s: UTF8String, dialect: String): Boolean = {
SQLConf.Dialect.withName(dialect) match {
case SQLConf.Dialect.SPARK =>
falseStrings.contains(s.toLowerCase)
case SQLConf.Dialect.POSTGRESQL =>
falseStringsOfPostgreSQL.contains(s.toLowerCase.trim())
}
}
// scalastyle:on caselocale
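Because the dialect reaches StringUtils as a plain string, the two predicates can be sanity-checked directly; a minimal sketch, with the expected results read off the sets above:

import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.unsafe.types.UTF8String

StringUtils.isTrueString(UTF8String.fromString("on"), "SPARK")          // false: "on" is not in trueStrings
StringUtils.isTrueString(UTF8String.fromString("on"), "POSTGRESQL")     // true: "on" is in trueStringsOfPostgreSQL
StringUtils.isFalseString(UTF8String.fromString(" of "), "POSTGRESQL")  // true: input is trimmed and "of" is a prefix of "off"
StringUtils.isFalseString(UTF8String.fromString(" no "), "SPARK")       // false: the Spark branch does not trim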

/**
@@ -1589,12 +1589,22 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision")
.internal()
.doc("When true, will perform integral division with the / operator " +
"if both sides are integral types. This is for PostgreSQL test cases only.")
.booleanConf
.createWithDefault(false)
object Dialect extends Enumeration {
val SPARK, POSTGRESQL = Value
}

val DIALECT =
buildConf("spark.sql.dialect")
.doc("The specific features of the SQL language to be adopted, which are available when " +
Review comment from gengliangwang (Member, Author): Let's have a follow-up PR to add a wiki page for the PostgreSQL dialect behaviors.
"accessing the given database. Currently, Spark supports two database dialects, `Spark` " +
"and `PostgreSQL`. With `PostgreSQL` dialect, Spark will: " +
"1. perform integral division with the / operator if both sides are integral types; " +
"2. accept \"true\", \"yes\", \"1\", \"false\", \"no\", \"0\", and unique prefixes as " +
"input and trim input for the boolean data type.")
.stringConf
.transform(_.toUpperCase(Locale.ROOT))
.checkValues(Dialect.values.map(_.toString))
.createWithDefault(Dialect.SPARK.toString)
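Combined with the Division and Cast changes above, the new conf can be exercised end to end; a minimal sketch, assuming a SparkSession named spark and the behaviors listed in the doc string:

spark.conf.set("spark.sql.dialect", "PostgreSQL")
spark.sql("SELECT 3 / 2").show()                  // 1: integral division on integral operands
spark.sql("SELECT CAST('of' AS boolean)").show()  // false: "of" is a unique prefix of "off"

spark.conf.set("spark.sql.dialect", "Spark")
spark.sql("SELECT 3 / 2").show()                  // 1.5: both operands cast to double
spark.sql("SELECT CAST('of' AS boolean)").show()  // null: "of" is not an accepted boolean string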

val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION =
buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation")
@@ -2418,8 +2428,6 @@ class SQLConf extends Serializable with Logging {

def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)

def preferIntegralDivision: Boolean = getConf(PREFER_INTEGRAL_DIVISION)

def allowCreatingManagedTableUsingNonemptyLocation: Boolean =
getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION)

@@ -1483,15 +1483,15 @@ class TypeCoercionSuite extends AnalysisTest {

test("SPARK-28395 Division operator support integral division") {
val rules = Seq(FunctionArgumentConversion, Division(conf))
Seq(true, false).foreach { preferIntegralDivision =>
withSQLConf(SQLConf.PREFER_INTEGRAL_DIVISION.key -> s"$preferIntegralDivision") {
val result1 = if (preferIntegralDivision) {
Seq(SQLConf.Dialect.SPARK, SQLConf.Dialect.POSTGRESQL).foreach { dialect =>
withSQLConf(SQLConf.DIALECT.key -> dialect.toString) {
val result1 = if (dialect == SQLConf.Dialect.POSTGRESQL) {
IntegralDivide(1L, 1L)
} else {
Divide(Cast(1L, DoubleType), Cast(1L, DoubleType))
}
ruleTest(rules, Divide(1L, 1L), result1)
val result2 = if (preferIntegralDivision) {
val result2 = if (dialect == SQLConf.Dialect.POSTGRESQL) {
IntegralDivide(1, Cast(1, ShortType))
} else {
Divide(Cast(1, DoubleType), Cast(Cast(1, ShortType), DoubleType))
@@ -818,37 +818,60 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
"interval 1 years 3 months -3 days")
}

test("cast string to boolean") {
checkCast("true", true)
checkCast("tru", true)
checkCast("tr", true)
checkCast("t", true)
checkCast("tRUe", true)
checkCast(" tRue ", true)
checkCast(" tRu ", true)
checkCast("yes", true)
checkCast("ye", true)
checkCast("y", true)
checkCast("1", true)
checkCast("on", true)

checkCast("false", false)
checkCast("fals", false)
checkCast("fal", false)
checkCast("fa", false)
checkCast("f", false)
checkCast(" fAlse ", false)
checkCast(" fAls ", false)
checkCast(" FAlsE ", false)
checkCast("no", false)
checkCast("n", false)
checkCast("0", false)
checkCast("off", false)
checkCast("of", false)

checkEvaluation(cast("o", BooleanType), null)
checkEvaluation(cast("abc", BooleanType), null)
checkEvaluation(cast("", BooleanType), null)
test("cast string to boolean with Spark dialect") {
withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.SPARK.toString) {
checkCast("t", true)
checkCast("true", true)
checkCast("tRUe", true)
checkCast("y", true)
checkCast("yes", true)
checkCast("1", true)

checkCast("f", false)
checkCast("false", false)
checkCast("FAlsE", false)
checkCast("n", false)
checkCast("no", false)
checkCast("0", false)

checkEvaluation(cast("abc", BooleanType), null)
checkEvaluation(cast("", BooleanType), null)
}
}

test("cast string to boolean with PostgreSQL dialect") {
withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.POSTGRESQL.toString) {
checkCast("true", true)
checkCast("tru", true)
checkCast("tr", true)
checkCast("t", true)
checkCast("tRUe", true)
checkCast(" tRue ", true)
checkCast(" tRu ", true)
checkCast("yes", true)
checkCast("ye", true)
checkCast("y", true)
checkCast("1", true)
checkCast("on", true)

checkCast("false", false)
checkCast("fals", false)
checkCast("fal", false)
checkCast("fa", false)
checkCast("f", false)
checkCast(" fAlse ", false)
checkCast(" fAls ", false)
checkCast(" FAlsE ", false)
checkCast("no", false)
checkCast("n", false)
checkCast("0", false)
checkCast("off", false)
checkCast("of", false)

checkEvaluation(cast("o", BooleanType), null)
checkEvaluation(cast("abc", BooleanType), null)
checkEvaluation(cast("", BooleanType), null)
}
}

test("SPARK-16729 type checking for casting to date type") {
@@ -311,8 +311,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
// PostgreSQL enabled cartesian product by default.
localSparkSession.conf.set(SQLConf.CROSS_JOINS_ENABLED.key, true)
localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true)
localSparkSession.conf.set(SQLConf.PREFER_INTEGRAL_DIVISION.key, true)
localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true)
localSparkSession.conf.set(SQLConf.DIALECT.key, SQLConf.Dialect.POSTGRESQL.toString)
case _ =>
}

@@ -111,7 +111,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
// PostgreSQL enabled cartesian product by default.
statement.execute(s"SET ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
statement.execute(s"SET ${SQLConf.ANSI_ENABLED.key} = true")
statement.execute(s"SET ${SQLConf.PREFER_INTEGRAL_DIVISION.key} = true")
statement.execute(s"SET ${SQLConf.DIALECT.key} = ${SQLConf.Dialect.POSTGRESQL.toString}")
case _ =>
}
