Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,8 @@ object FunctionRegistry {
} else {
// Otherwise, find a constructor method that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
val f = constructors.find(e => e.getParameterTypes.toSeq == params
|| e.getParameterTypes.head == classOf[String]).getOrElse {
Copy link
Contributor

@cloud-fan cloud-fan Dec 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like it's less hacky to create a new expressionWithAlias method, with only the necessary logic

def expressionWithAlias ... = {
  val constructors = tag.runtimeClass.getConstructors
    .filter(c => e.getParameterTypes.head == classOf[String])
  assert(constructors.length == 1)
  try {
    constructors.head.newInstance(name, expressions : _*).asInstanceOf[Expression]
  } ...
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then we don't even need the MultiNamedExpression trait. We just need to register bool_and, bool_or with expressionWithAlias

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan updated as per your suggestions.

val validParametersCount = constructors
.filter(_.getParameterTypes.forall(_ == classOf[Expression]))
.map(_.getParameterCount).distinct.sorted
Expand All @@ -618,7 +619,13 @@ object FunctionRegistry {
}
throw new AnalysisException(invalidArgumentsMsg)
}
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
Try{
if (classOf[MultiNamedExpression].isAssignableFrom(f.getDeclaringClass)) {
f.newInstance(name.toString, expressions.head).asInstanceOf[Expression]
} else {
f.newInstance(expressions : _*).asInstanceOf[Expression]
}
} match {
case Success(e) => e
case Failure(e) =>
// the exception is an invocation exception. To get a meaningful message, we need the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
}
}

trait MultiNamedExpression {
}

@ExpressionDescription(
usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.",
examples = """
Expand All @@ -52,8 +55,9 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
false
""",
since = "3.0.0")
case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
override def nodeName: String = "bool_and"
case class BoolAnd(funcName: String, arg: Expression)
extends UnevaluableBooleanAggBase(arg) with MultiNamedExpression {
override def nodeName: String = funcName
}

@ExpressionDescription(
Expand All @@ -68,6 +72,7 @@ case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
false
""",
since = "3.0.0")
case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
override def nodeName: String = "bool_or"
case class BoolOr(funcName: String, arg: Expression)
extends UnevaluableBooleanAggBase(arg) with MultiNamedExpression {
override def nodeName: String = funcName
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e: RuntimeReplaceable => e.child
case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
case BoolOr(arg) => Max(arg)
case BoolAnd(arg) => Min(arg)
case BoolOr(_, arg) => Max(arg)
case BoolAnd(_, arg) => Min(arg)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(Sum('stringField))
assertSuccess(Average('stringField))
assertSuccess(Min('arrayField))
assertSuccess(new BoolAnd('booleanField))
assertSuccess(new BoolOr('booleanField))
assertSuccess(new BoolAnd("bool_and", 'booleanField))
assertSuccess(new BoolOr("bool_or", 'booleanField))

assertError(Min('mapField), "min does not support ordering on type")
assertError(Max('mapField), "max does not support ordering on type")
Expand Down