diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 89890ea08641..88085636a5ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -229,7 +229,7 @@ case class BitwiseCount(child: Expression) override def prettyName: String = "bit_count" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.dataType match { - case BooleanType => defineCodeGen(ctx, ev, c => s"if ($c) 1 else 0") + case BooleanType => defineCodeGen(ctx, ev, c => s"($c) ? 1 : 0") case _ => defineCodeGen(ctx, ev, c => s"java.lang.Long.bitCount($c)") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala index 9089c6f17d40..63602d04b5c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala @@ -134,6 +134,47 @@ class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("BitCount") { + // null + val nullLongLiteral = Literal.create(null, LongType) + val nullIntLiteral = Literal.create(null, IntegerType) + val nullBooleanLiteral = Literal.create(null, BooleanType) + checkEvaluation(BitwiseCount(nullLongLiteral), null) + checkEvaluation(BitwiseCount(nullIntLiteral), null) + checkEvaluation(BitwiseCount(nullBooleanLiteral), null) + + // boolean + checkEvaluation(BitwiseCount(Literal(true)), 1) + checkEvaluation(BitwiseCount(Literal(false)), 0) + + // byte/tinyint + checkEvaluation(BitwiseCount(Literal(1.toByte)), 1) + checkEvaluation(BitwiseCount(Literal(2.toByte)), 1) + checkEvaluation(BitwiseCount(Literal(3.toByte)), 2) + + // short/smallint + checkEvaluation(BitwiseCount(Literal(1.toShort)), 1) + checkEvaluation(BitwiseCount(Literal(2.toShort)), 1) + checkEvaluation(BitwiseCount(Literal(3.toShort)), 2) + + // int + checkEvaluation(BitwiseCount(Literal(1)), 1) + checkEvaluation(BitwiseCount(Literal(2)), 1) + checkEvaluation(BitwiseCount(Literal(3)), 2) + + // long/bigint + checkEvaluation(BitwiseCount(Literal(1L)), 1) + checkEvaluation(BitwiseCount(Literal(2L)), 1) + checkEvaluation(BitwiseCount(Literal(3L)), 2) + + // negative num + checkEvaluation(BitwiseCount(Literal(-1L)), 64) + + // edge value + checkEvaluation(BitwiseCount(Literal(9223372036854775807L)), 63) + checkEvaluation(BitwiseCount(Literal(-9223372036854775808L)), 1) + } + test("BitGet") { val nullLongLiteral = Literal.create(null, LongType) val nullIntLiteral = Literal.create(null, IntegerType)