diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 04e8963944fd..52e05b820366 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -527,6 +527,7 @@ object FunctionRegistry { expression[BitwiseCount]("bit_count"), expression[BitAndAgg]("bit_and"), expression[BitOrAgg]("bit_or"), + expression[BitXorAgg]("bit_xor"), // json expression[StructsToJson]("to_json"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala index 131fa2eb5055..b77c3bd9cbde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala @@ -17,20 +17,14 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BitwiseAnd, BitwiseOr, ExpectsInputTypes, Expression, ExpressionDescription, If, IsNull, Literal} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryArithmetic, BitwiseAnd, BitwiseOr, BitwiseXor, ExpectsInputTypes, Expression, ExpressionDescription, If, IsNull, Literal} import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType} -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the bitwise AND of all non-null input values, or null if none.", - examples = """ - Examples: - > SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col); - 1 - """, - since = "3.0.0") -case class BitAndAgg(child: Expression) extends DeclarativeAggregate with 
ExpectsInputTypes { +abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes { - override def nodeName: String = "bit_and" + val child: Expression + + def bitOperator(left: Expression, right: Expression): BinaryArithmetic override def children: Seq[Expression] = child :: Nil @@ -40,23 +34,40 @@ case class BitAndAgg(child: Expression) extends DeclarativeAggregate with Expect override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) - private lazy val bitAnd = AttributeReference("bit_and", child.dataType)() - - override lazy val aggBufferAttributes: Seq[AttributeReference] = bitAnd :: Nil + private lazy val bitAgg = AttributeReference(nodeName, child.dataType)() override lazy val initialValues: Seq[Literal] = Literal.create(null, dataType) :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = bitAgg :: Nil + + override lazy val evaluateExpression: AttributeReference = bitAgg + override lazy val updateExpressions: Seq[Expression] = - If(IsNull(bitAnd), + If(IsNull(bitAgg), child, - If(IsNull(child), bitAnd, BitwiseAnd(bitAnd, child))) :: Nil + If(IsNull(child), bitAgg, bitOperator(bitAgg, child))) :: Nil override lazy val mergeExpressions: Seq[Expression] = - If(IsNull(bitAnd.left), - bitAnd.right, - If(IsNull(bitAnd.right), bitAnd.left, BitwiseAnd(bitAnd.left, bitAnd.right))) :: Nil + If(IsNull(bitAgg.left), + bitAgg.right, + If(IsNull(bitAgg.right), bitAgg.left, bitOperator(bitAgg.left, bitAgg.right))) :: Nil +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the bitwise AND of all non-null input values, or null if none.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col); + 1 + """, + since = "3.0.0") +case class BitAndAgg(child: Expression) extends BitAggregate { - override lazy val evaluateExpression: AttributeReference = bitAnd + override def nodeName: String = "bit_and" + + override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { + 
BitwiseAnd(left, right) + } } @ExpressionDescription( @@ -67,33 +78,28 @@ case class BitAndAgg(child: Expression) extends DeclarativeAggregate with Expect 7 """, since = "3.0.0") -case class BitOrAgg(child: Expression) extends DeclarativeAggregate with ExpectsInputTypes { +case class BitOrAgg(child: Expression) extends BitAggregate { override def nodeName: String = "bit_or" - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType - - override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) - - private lazy val bitOr = AttributeReference("bit_or", child.dataType)() - - override lazy val aggBufferAttributes: Seq[AttributeReference] = bitOr :: Nil - - override lazy val initialValues: Seq[Literal] = Literal.create(null, dataType) :: Nil + override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { + BitwiseOr(left, right) + } +} - override lazy val updateExpressions: Seq[Expression] = - If(IsNull(bitOr), - child, - If(IsNull(child), bitOr, BitwiseOr(bitOr, child))) :: Nil +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the bitwise XOR of all non-null input values, or null if none.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col); + 6 + """, + since = "3.0.0") +case class BitXorAgg(child: Expression) extends BitAggregate { - override lazy val mergeExpressions: Seq[Expression] = - If(IsNull(bitOr.left), - bitOr.right, - If(IsNull(bitOr.right), bitOr.left, BitwiseOr(bitOr.left, bitOr.right))) :: Nil + override def nodeName: String = "bit_xor" - override lazy val evaluateExpression: AttributeReference = bitOr + override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = { + BitwiseXor(left, right) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql index 993eecf0f89b..5e665e4c0c38 100644 --- 
a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql @@ -37,3 +37,34 @@ select bit_count(-9223372036854775808L); -- other illegal arguments select bit_count("bit count"); select bit_count('a'); + +-- test for bit_xor +-- +CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES + (1, 1, 1, 1L), + (2, 3, 4, null), + (7, 7, 7, 3L) AS bitwise_test(b1, b2, b3, b4); + +-- empty case +SELECT BIT_XOR(b3) AS n1 FROM bitwise_test where 1 = 0; + +-- null case +SELECT BIT_XOR(b4) AS n1 FROM bitwise_test where b4 is null; + +-- the suffix numbers show the expected answer +SELECT + BIT_XOR(cast(b1 as tinyint)) AS a4, + BIT_XOR(cast(b2 as smallint)) AS b5, + BIT_XOR(b3) AS c2, + BIT_XOR(b4) AS d2, + BIT_XOR(distinct b4) AS e2 +FROM bitwise_test; + +-- group by +SELECT bit_xor(b3) FROM bitwise_test GROUP BY b1 & 1; + +--having +SELECT b1, bit_xor(b2) FROM bitwise_test GROUP BY b1 HAVING bit_and(b2) < 7; + +-- window +SELECT b1, b2, bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2) FROM bitwise_test; diff --git a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out index 7cbd26e87bd2..42c22a317eb4 100644 --- a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 27 -- !query 0 @@ -162,3 +162,72 @@ struct<> -- !query 19 output org.apache.spark.sql.AnalysisException cannot resolve 'bit_count('a')' due to data type mismatch: argument 1 requires (integral or boolean) type, however, ''a'' is of string type.; line 1 pos 7 + + +-- !query 20 +CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES + (1, 1, 1, 1L), + (2, 3, 4, null), + (7, 7, 7, 3L) AS bitwise_test(b1, b2, b3, b4) +-- !query 20 schema +struct<> +-- !query 20 output + + + 
+-- !query 21 +SELECT BIT_XOR(b3) AS n1 FROM bitwise_test where 1 = 0 +-- !query 21 schema +struct<n1:int> +-- !query 21 output +NULL + + +-- !query 22 +SELECT BIT_XOR(b4) AS n1 FROM bitwise_test where b4 is null +-- !query 22 schema +struct<n1:bigint> +-- !query 22 output +NULL + + +-- !query 23 +SELECT + BIT_XOR(cast(b1 as tinyint)) AS a4, + BIT_XOR(cast(b2 as smallint)) AS b5, + BIT_XOR(b3) AS c2, + BIT_XOR(b4) AS d2, + BIT_XOR(distinct b4) AS e2 +FROM bitwise_test +-- !query 23 schema +struct<a4:tinyint,b5:smallint,c2:int,d2:bigint,e2:bigint> +-- !query 23 output +4 5 2 2 2 + + +-- !query 24 +SELECT bit_xor(b3) FROM bitwise_test GROUP BY b1 & 1 +-- !query 24 schema +struct<bit_xor(b3):int> +-- !query 24 output +4 +6 + + +-- !query 25 +SELECT b1, bit_xor(b2) FROM bitwise_test GROUP BY b1 HAVING bit_and(b2) < 7 +-- !query 25 schema +struct<b1:int,bit_xor(b2):int> +-- !query 25 output +1 1 +2 3 + + +-- !query 26 +SELECT b1, b2, bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2) FROM bitwise_test +-- !query 26 schema +struct<b1:int,b2:int,bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int> +-- !query 26 output +1 1 1 +2 3 3 +7 7 7