diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index ad9be300f9..29a0f880ab 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -141,6 +141,11 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[MapValues] -> CometMapValues, classOf[MapFromArrays] -> CometMapFromArrays, classOf[GetMapValue] -> CometMapExtract, + classOf[EqualTo] -> CometEqualTo, + classOf[EqualNullSafe] -> CometEqualNullSafe, + classOf[Not] -> CometNot, + classOf[And] -> CometAnd, + classOf[Or] -> CometOr, classOf[GreaterThan] -> CometGreaterThan, classOf[GreaterThanOrEqual] -> CometGreaterThanOrEqual, classOf[LessThan] -> CometLessThan, @@ -715,42 +720,6 @@ object QueryPlanSerde extends Logging with CometExprShim { case c @ Cast(child, dt, timeZoneId, _) => handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c)) - case EqualTo(left, right) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setEq(binaryExpr)) - - case Not(EqualTo(left, right)) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setNeq(binaryExpr)) - - case EqualNullSafe(left, right) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setEqNullSafe(binaryExpr)) - - case Not(EqualNullSafe(left, right)) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr)) - case Literal(value, dataType) if supportedDataType( dataType, @@ -955,24 +924,6 @@ object QueryPlanSerde extends Logging with CometExprShim { None } - case And(left, right) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setAnd(binaryExpr)) - - case Or(left, right) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setOr(binaryExpr)) - case UnaryExpression(child) if expr.prettyName == "promote_precision" => // `UnaryExpression` includes `PromotePrecision` for Spark 3.3 // `PromotePrecision` is just a wrapper, don't need to serialize it. @@ -1162,17 +1113,6 @@ object QueryPlanSerde extends Logging with CometExprShim { None } - case n @ Not(In(_, _)) => - CometNotIn.convert(n, inputs, binding) - - case Not(child) => - createUnaryExpr( - expr, - child, - inputs, - binding, - (builder, unaryExpr) => builder.setNot(unaryExpr)) - case UnaryMinus(child, failOnError) => val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { diff --git a/spark/src/main/scala/org/apache/comet/serde/comparisons.scala b/spark/src/main/scala/org/apache/comet/serde/predicates.scala similarity index 69% rename from spark/src/main/scala/org/apache/comet/serde/comparisons.scala rename to spark/src/main/scala/org/apache/comet/serde/predicates.scala index b0b3d3329e..f4e746c276 100644 --- a/spark/src/main/scala/org/apache/comet/serde/comparisons.scala +++ b/spark/src/main/scala/org/apache/comet/serde/predicates.scala @@ -21,13 +21,109 @@ package org.apache.comet.serde import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GreaterThan, GreaterThanOrEqual, In, InSet, IsNaN, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, InSet, IsNaN, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or} import org.apache.spark.sql.types.BooleanType import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde._ +object CometNot extends CometExpressionSerde[Not] { + override def convert( + expr: Not, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + + expr.child match { + case expr: EqualTo => + createBinaryExpr( + expr, + expr.left, + expr.right, + inputs, + binding, + (builder, binaryExpr) => builder.setNeq(binaryExpr)) + case expr: EqualNullSafe => + createBinaryExpr( + expr, + expr.left, + expr.right, + inputs, + binding, + (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr)) + case expr: In => + ComparisonUtils.in(expr, expr.value, expr.list, inputs, binding, negate = true) + case _ => + createUnaryExpr( + expr, + expr.child, + inputs, + binding, + (builder, unaryExpr) => builder.setNot(unaryExpr)) + } + } +} + +object CometAnd extends CometExpressionSerde[And] { + override def convert( + expr: And, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + expr.left, + expr.right, + inputs, + binding, + (builder, binaryExpr) => builder.setAnd(binaryExpr)) + } +} + +object CometOr extends CometExpressionSerde[Or] { + override def convert( + expr: Or, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + expr.left, + expr.right, + inputs, + binding, + (builder, binaryExpr) => builder.setOr(binaryExpr)) + } +} + +object CometEqualTo extends CometExpressionSerde[EqualTo] { + override def convert( + expr: EqualTo, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + expr.left, + expr.right, + inputs, + binding, + (builder, binaryExpr) => builder.setEq(binaryExpr)) + } +} + +object CometEqualNullSafe extends CometExpressionSerde[EqualNullSafe] { + override def convert( + expr: EqualNullSafe, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + expr.left, + expr.right, + inputs, + binding, + (builder, binaryExpr) => builder.setEqNullSafe(binaryExpr)) + } +} + object CometGreaterThan extends CometExpressionSerde[GreaterThan] { override def convert( expr: GreaterThan, @@ -137,16 +233,6 @@ object CometIn extends CometExpressionSerde[In] { } } -object CometNotIn extends CometExpressionSerde[Not] { - override def convert( - expr: Not, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val inExpr = expr.child.asInstanceOf[In] - ComparisonUtils.in(expr, inExpr.value, inExpr.list, inputs, binding, negate = true) - } -} - object CometInSet extends CometExpressionSerde[InSet] { override def convert( expr: InSet,