diff --git a/docs/sql-keywords.md b/docs/sql-keywords.md index 816b4eaa0ca63..1247d6e2d5f41 100644 --- a/docs/sql-keywords.md +++ b/docs/sql-keywords.md @@ -273,6 +273,7 @@ Below is a list of all the keywords in Spark SQL. UNCACHEnon-reservednon-reservednon-reserved UNIONreservedstrict-non-reservedreserved UNIQUEreservednon-reservedreserved + UNKNOWNreservednon-reservedreserved UNLOCKnon-reservednon-reservednon-reserved UNSETnon-reservednon-reservednon-reserved USEnon-reservednon-reservednon-reserved diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index a1c11504a9036..49adf86c12543 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -660,6 +660,7 @@ predicate | NOT? kind=IN '(' query ')' | NOT? kind=(RLIKE | LIKE) pattern=valueExpression | IS NOT? kind=NULL + | IS NOT? kind=(TRUE | FALSE | UNKNOWN) | IS NOT? 
/**
 * Common base for the SQL boolean test predicates `IS [NOT] TRUE` and
 * `IS [NOT] FALSE`. The child expression must be of boolean type.
 * Unlike an ordinary comparison, a boolean test never returns NULL:
 * a NULL input evaluates to `boolValueWhenNull`.
 */
trait BooleanTest extends UnaryExpression with Predicate with ExpectsInputTypes {

  // The boolean constant the child value is tested against (TRUE or FALSE).
  def boolValueForComparison: Boolean

  // Result for a NULL input: false for the IS TRUE / IS FALSE forms,
  // true for the IS NOT TRUE / IS NOT FALSE forms.
  def boolValueWhenNull: Boolean

  // A boolean test always yields a definite true/false, never NULL.
  override def nullable: Boolean = false

  override def inputTypes: Seq[DataType] = Seq(BooleanType)

  override def eval(input: InternalRow): Any = {
    val value = child.eval(input)
    if (value == null) {
      boolValueWhenNull
    } else if (boolValueWhenNull) {
      // NOT variant: the test holds when the value differs from the constant.
      value != boolValueForComparison
    } else {
      value == boolValueForComparison
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    // boolValueWhenNull and boolValueForComparison are compile-time constants,
    // so the comparison operator can be chosen while generating the code
    // instead of branching on it in the generated Java.
    val op = if (boolValueWhenNull) "!=" else "=="
    ev.copy(code = code"""
      ${eval.code}
      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
      if (${eval.isNull}) {
        ${ev.value} = $boolValueWhenNull;
      } else {
        ${ev.value} = ${eval.value} $op $boolValueForComparison;
      }
    """, isNull = FalseLiteral)
  }
}

/** `expr IS TRUE`: true iff `expr` evaluates to TRUE; a NULL input yields false. */
case class IsTrue(child: Expression) extends BooleanTest {
  override def boolValueForComparison: Boolean = true
  override def boolValueWhenNull: Boolean = false
  override def sql: String = s"(${child.sql} IS TRUE)"
}

/** `expr IS NOT TRUE`: true iff `expr` is FALSE or NULL. */
case class IsNotTrue(child: Expression) extends BooleanTest {
  override def boolValueForComparison: Boolean = true
  override def boolValueWhenNull: Boolean = true
  override def sql: String = s"(${child.sql} IS NOT TRUE)"
}

/** `expr IS FALSE`: true iff `expr` evaluates to FALSE; a NULL input yields false. */
case class IsFalse(child: Expression) extends BooleanTest {
  override def boolValueForComparison: Boolean = false
  override def boolValueWhenNull: Boolean = false
  override def sql: String = s"(${child.sql} IS FALSE)"
}

/** `expr IS NOT FALSE`: true iff `expr` is TRUE or NULL. */
case class IsNotFalse(child: Expression) extends BooleanTest {
  override def boolValueForComparison: Boolean = false
  override def boolValueWhenNull: Boolean = true
  override def sql: String = s"(${child.sql} IS NOT FALSE)"
}
/**
 * Factory for the `expr IS UNKNOWN` predicate. Semantically identical to
 * `IS NULL`, except that the input expression must be of boolean type,
 * which is enforced by mixing in [[ExpectsInputTypes]].
 */
object IsUnknown {
  def apply(child: Expression): Predicate =
    new IsNull(child) with ExpectsInputTypes {
      override def sql: String = s"(${child.sql} IS UNKNOWN)"
      override def inputTypes: Seq[DataType] = Seq(BooleanType)
    }
}

/**
 * Factory for the `expr IS NOT UNKNOWN` predicate. Semantically identical to
 * `IS NOT NULL`, except that the input expression must be of boolean type,
 * which is enforced by mixing in [[ExpectsInputTypes]].
 */
object IsNotUnknown {
  def apply(child: Expression): Predicate =
    new IsNotNull(child) with ExpectsInputTypes {
      override def sql: String = s"(${child.sql} IS NOT UNKNOWN)"
      override def inputTypes: Seq[DataType] = Seq(BooleanType)
    }
}
+ * - IS (NOT) (TRUE | FALSE | UNKNOWN) * - IS (NOT) DISTINCT FROM */ private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { @@ -1243,6 +1244,18 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging IsNotNull(e) case SqlBaseParser.NULL => IsNull(e) + case SqlBaseParser.TRUE => ctx.NOT match { + case null => IsTrue(e) + case _ => IsNotTrue(e) + } + case SqlBaseParser.FALSE => ctx.NOT match { + case null => IsFalse(e) + case _ => IsNotFalse(e) + } + case SqlBaseParser.UNKNOWN => ctx.NOT match { + case null => IsUnknown(e) + case _ => IsNotUnknown(e) + } case SqlBaseParser.DISTINCT if ctx.NOT != null => EqualNullSafe(e, expression(ctx.right)) case SqlBaseParser.DISTINCT => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 9b6896f65abfd..7bff277c793ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -521,4 +521,43 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { val expected = "(('id = 1) OR ('id = 2))" assert(expression == expected) } + + val row0 = create_row(null) + val row1 = create_row(false) + val row2 = create_row(true) + + test("istrue and isnottrue") { + checkEvaluation(IsTrue(Literal.create(null, BooleanType)), false, row0) + checkEvaluation(IsNotTrue(Literal.create(null, BooleanType)), true, row0) + checkEvaluation(IsTrue(Literal.create(false, BooleanType)), false, row1) + checkEvaluation(IsNotTrue(Literal.create(false, BooleanType)), true, row1) + checkEvaluation(IsTrue(Literal.create(true, BooleanType)), true, row2) + checkEvaluation(IsNotTrue(Literal.create(true, BooleanType)), false, row2) + IsTrue(Literal.create(null, IntegerType)).checkInputDataTypes() 
match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("argument 1 requires boolean type")) + } + } + + test("isfalse and isnotfalse") { + checkEvaluation(IsFalse(Literal.create(null, BooleanType)), false, row0) + checkEvaluation(IsNotFalse(Literal.create(null, BooleanType)), true, row0) + checkEvaluation(IsFalse(Literal.create(false, BooleanType)), true, row1) + checkEvaluation(IsNotFalse(Literal.create(false, BooleanType)), false, row1) + checkEvaluation(IsFalse(Literal.create(true, BooleanType)), false, row2) + checkEvaluation(IsNotFalse(Literal.create(true, BooleanType)), true, row2) + IsFalse(Literal.create(null, IntegerType)).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("argument 1 requires boolean type")) + } + } + + test("isunknown and isnotunknown") { + checkEvaluation(IsUnknown(Literal.create(null, BooleanType)), true, row0) + checkEvaluation(IsNotUnknown(Literal.create(null, BooleanType)), false, row0) + IsUnknown(Literal.create(null, IntegerType)).checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("argument 1 requires boolean type")) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index ba01380558530..f610a1be5695f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -536,6 +536,7 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { "uncache", "union", "unique", + "unknown", "unlock", "unset", "use", @@ -621,6 +622,7 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper { "trailing", "union", "unique", + "unknown", "user", "using", "when",