Skip to content
110 changes: 11 additions & 99 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[MapValues] -> CometMapValues,
classOf[MapFromArrays] -> CometMapFromArrays,
classOf[GetMapValue] -> CometMapExtract,
classOf[GreaterThan] -> CometGreaterThan,
classOf[GreaterThanOrEqual] -> CometGreaterThanOrEqual,
classOf[LessThan] -> CometLessThan,
classOf[LessThanOrEqual] -> CometLessThanOrEqual,
classOf[IsNull] -> CometIsNull,
classOf[IsNotNull] -> CometIsNotNull,
classOf[IsNaN] -> CometIsNaN,
classOf[In] -> CometIn,
classOf[InSet] -> CometInSet,
classOf[Rand] -> CometRand,
classOf[Randn] -> CometRandn,
classOf[SparkPartitionID] -> CometSparkPartitionId,
Expand Down Expand Up @@ -675,42 +684,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
binding,
(builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr))

case GreaterThan(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setGt(binaryExpr))

case GreaterThanOrEqual(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setGtEq(binaryExpr))

case LessThan(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setLt(binaryExpr))

case LessThanOrEqual(left, right) =>
createBinaryExpr(
expr,
left,
right,
inputs,
binding,
(builder, binaryExpr) => builder.setLtEq(binaryExpr))

case Literal(value, dataType)
if supportedDataType(dataType, allowComplex = value == null) =>
val exprBuilder = ExprOuterClass.Literal.newBuilder()
Expand Down Expand Up @@ -1065,29 +1038,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
})
optExprWithInfo(optExpr, expr, child)

case IsNull(child) =>
createUnaryExpr(
expr,
child,
inputs,
binding,
(builder, unaryExpr) => builder.setIsNull(unaryExpr))

case IsNotNull(child) =>
createUnaryExpr(
expr,
child,
inputs,
binding,
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr))

case IsNaN(child) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val optExpr =
scalarFunctionExprToProtoWithReturnType("isnan", BooleanType, childExpr)

optExprWithInfo(optExpr, expr, child)

case SortOrder(child, direction, nullOrdering, _) =>
val childExpr = exprToProtoInternal(child, inputs, binding)

Expand Down Expand Up @@ -1457,20 +1407,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
binding,
(builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))

case In(value, list) =>
in(expr, value, list, inputs, binding, negate = false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand now why you brought that in function over to comparisons.scala. It was used here, and in InSet and Not(In).

I think we should handle those two expressions in this PR, and bring in over to comparisons.scala so the logic isn't duplicated in two places.


case InSet(value, hset) =>
val valueDataType = value.dataType
val list = hset.map { setVal =>
Literal(setVal, valueDataType)
}.toSeq
// Change `InSet` to `In` expression
// We do Spark `InSet` optimization in native (DataFusion) side.
in(expr, value, list, inputs, binding, negate = false)

case Not(In(value, list)) =>
in(expr, value, list, inputs, binding, negate = true)
case Not(In(_, _)) =>
CometNotIn.convert(expr, inputs, binding)

case Not(child) =>
createUnaryExpr(
Expand Down Expand Up @@ -1814,32 +1752,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
}
}

def in(
expr: Expression,
value: Expression,
list: Seq[Expression],
inputs: Seq[Attribute],
binding: Boolean,
negate: Boolean): Option[Expr] = {
val valueExpr = exprToProtoInternal(value, inputs, binding)
val listExprs = list.map(exprToProtoInternal(_, inputs, binding))
if (valueExpr.isDefined && listExprs.forall(_.isDefined)) {
val builder = ExprOuterClass.In.newBuilder()
builder.setInValue(valueExpr.get)
builder.addAllLists(listExprs.map(_.get).asJava)
builder.setNegated(negate)
Some(
ExprOuterClass.Expr
.newBuilder()
.setIn(builder)
.build())
} else {
val allExprs = list ++ Seq(value)
withInfo(expr, allExprs: _*)
None
}
}

def scalarFunctionExprToProtoWithReturnType(
funcName: String,
returnType: DataType,
Expand Down
208 changes: 208 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/comparisons.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.serde

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GreaterThan, GreaterThanOrEqual, In, InSet, IsNaN, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not}
import org.apache.spark.sql.types.BooleanType

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde._

object CometGreaterThan extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val greaterThan = expr.asInstanceOf[GreaterThan]

createBinaryExpr(
expr,
greaterThan.left,
greaterThan.right,
inputs,
binding,
(builder, binaryExpr) => builder.setGt(binaryExpr))
}
}

object CometGreaterThanOrEqual extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val greaterThanOrEqual = expr.asInstanceOf[GreaterThanOrEqual]

createBinaryExpr(
expr,
greaterThanOrEqual.left,
greaterThanOrEqual.right,
inputs,
binding,
(builder, binaryExpr) => builder.setGtEq(binaryExpr))
}
}

object CometLessThan extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val lessThan = expr.asInstanceOf[LessThan]

createBinaryExpr(
expr,
lessThan.left,
lessThan.right,
inputs,
binding,
(builder, binaryExpr) => builder.setLt(binaryExpr))
}
}

object CometLessThanOrEqual extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val lessThanOrEqual = expr.asInstanceOf[LessThanOrEqual]

createBinaryExpr(
expr,
lessThanOrEqual.left,
lessThanOrEqual.right,
inputs,
binding,
(builder, binaryExpr) => builder.setLtEq(binaryExpr))
}
}

object CometIsNull extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val isNull = expr.asInstanceOf[IsNull]

createUnaryExpr(
expr,
isNull.child,
inputs,
binding,
(builder, unaryExpr) => builder.setIsNull(unaryExpr))
}
}

object CometIsNotNull extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val isNotNull = expr.asInstanceOf[IsNotNull]

createUnaryExpr(
expr,
isNotNull.child,
inputs,
binding,
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
}
}

object CometIsNaN extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val isNaN = expr.asInstanceOf[IsNaN]
val childExpr = exprToProtoInternal(isNaN.child, inputs, binding)
val optExpr = scalarFunctionExprToProtoWithReturnType("isnan", BooleanType, childExpr)

optExprWithInfo(optExpr, expr, isNaN.child)
}
}

object CometIn extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val inExpr = expr.asInstanceOf[In]
ComparisonUtils.in(expr, inExpr.value, inExpr.list, inputs, binding, negate = false)
}
}

object CometNotIn extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val notExpr = expr.asInstanceOf[Not]
val inExpr = notExpr.child.asInstanceOf[In]
ComparisonUtils.in(expr, inExpr.value, inExpr.list, inputs, binding, negate = true)
}
}

object CometInSet extends CometExpressionSerde {
override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val inSetExpr = expr.asInstanceOf[InSet]
val valueDataType = inSetExpr.child.dataType
val list = inSetExpr.hset.map { setVal =>
Literal(setVal, valueDataType)
}.toSeq
// Change `InSet` to `In` expression
// We do Spark `InSet` optimization in native (DataFusion) side.
ComparisonUtils.in(expr, inSetExpr.child, list, inputs, binding, negate = false)
}
}

object ComparisonUtils {

def in(
expr: Expression,
value: Expression,
list: Seq[Expression],
inputs: Seq[Attribute],
binding: Boolean,
negate: Boolean): Option[Expr] = {
val valueExpr = exprToProtoInternal(value, inputs, binding)
val listExprs = list.map(exprToProtoInternal(_, inputs, binding))
if (valueExpr.isDefined && listExprs.forall(_.isDefined)) {
val builder = ExprOuterClass.In.newBuilder()
builder.setInValue(valueExpr.get)
builder.addAllLists(listExprs.map(_.get).asJava)
builder.setNegated(negate)
Some(
ExprOuterClass.Expr
.newBuilder()
.setIn(builder)
.build())
} else {
val allExprs = list ++ Seq(value)
withInfo(expr, allExprs: _*)
None
}
}
}
Loading