Skip to content
Closed
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
29fdaba
expressionWithAlias for First, Last, StddevSamp, VarianceSamp
amanomer Dec 9, 2019
7d5f4be
Fixed errors
amanomer Dec 9, 2019
7ba7802
+constructor First
amanomer Dec 10, 2019
759262d
+constructor Last
amanomer Dec 10, 2019
5ec101f
Fixed build error
amanomer Dec 10, 2019
f210bb9
Fixed TC
amanomer Dec 10, 2019
cc234b9
ScalaStyle Fix
amanomer Dec 10, 2019
7617ec3
Fixed TC
amanomer Dec 11, 2019
92e381b
Reduce constructors
amanomer Dec 11, 2019
a2d75de
nit
amanomer Dec 11, 2019
3ef62af
Removed unnecessary changes
amanomer Dec 11, 2019
87c6ea3
Fixed TC
amanomer Dec 11, 2019
9480532
Override flatArguments in VarianceSamp, StddevSamp
amanomer Dec 11, 2019
5c540fc
override flatArguments in First, Last
amanomer Dec 12, 2019
d780dfc
override flatArguments in BoolAnd, BoolOr
amanomer Dec 12, 2019
85d9597
add assert()
amanomer Dec 12, 2019
a71e8a7
expressionWithAlias for Average, ApproximatePercentile & override nod…
amanomer Dec 13, 2019
dd2d85d
Fixes for latest update
amanomer Dec 13, 2019
c1b3afb
Fix ApproximatePercentile TC
amanomer Dec 14, 2019
e7a4e90
UT fix
amanomer Dec 14, 2019
ca886f0
nit
amanomer Dec 18, 2019
125cfac
expressionWithTreeNodeTag for ApproximatePercentile
amanomer Dec 18, 2019
4ca20f4
expressionWithTreeNodeTag for BoolAnd, BoolOr, StddevSamp and Varianc…
amanomer Dec 18, 2019
aecdd8a
expressionWithTreeNodeTag for First, Last and Average
amanomer Dec 18, 2019
bbd4397
Renaming to expressionWithTNT
amanomer Dec 18, 2019
9146913
nit
amanomer Dec 18, 2019
8e9e42b
Avoid duplicate code
amanomer Dec 18, 2019
36418e2
small fix
amanomer Dec 18, 2019
ce8ea17
move FUNC_ALIAS to FunctionRegistry
amanomer Dec 19, 2019
4b536dd
Remove expressionWithAlias
amanomer Dec 19, 2019
737f33a
revert reorder
amanomer Dec 19, 2019
1920940
override prettyName instead of nodeName
amanomer Dec 19, 2019
700a84d
Merge branch 'master' into fncAlias
amanomer Dec 19, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -692,10 +692,10 @@ class Analyzer(
// Assumption is the aggregate function ignores nulls. This is true for all current
// AggregateFunction's with the exception of First and Last in their default mode
// (which we handle) and possibly some Hive UDAF's.
case First(expr, _) =>
First(ifExpr(expr), Literal(true))
case Last(expr, _) =>
Last(ifExpr(expr), Literal(true))
case First(funcName, expr, _) =>
First(funcName, ifExpr(expr), Literal(true))
case Last(funcName, expr, _) =>
Last(funcName, ifExpr(expr), Literal(true))
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
}.transform {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,34 +282,34 @@ object FunctionRegistry {

// aggregate functions
expression[HyperLogLogPlusPlus]("approx_count_distinct"),
expression[Average]("avg"),
expressionWithAlias[Average]("avg"),
expressionWithAlias[Average]("mean"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems we don't need expressionWithAlias at all, just call expression[Average]("mean", setAlias = true).

And expression[Average]("avg") can remain unchanged, as avg is already the default name.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And we don't need to reorder them now.

expression[Corr]("corr"),
expression[Count]("count"),
expression[CountIf]("count_if"),
expression[CovPopulation]("covar_pop"),
expression[CovSample]("covar_samp"),
expression[First]("first"),
expression[First]("first_value"),
expressionWithAlias[First]("first"),
expressionWithAlias[First]("first_value"),
expression[Kurtosis]("kurtosis"),
expression[Last]("last"),
expression[Last]("last_value"),
expressionWithAlias[Last]("last"),
expressionWithAlias[Last]("last_value"),
expression[Max]("max"),
expression[MaxBy]("max_by"),
expression[Average]("mean"),
expression[Min]("min"),
expression[MinBy]("min_by"),
expression[Percentile]("percentile"),
expression[Skewness]("skewness"),
expression[ApproximatePercentile]("percentile_approx"),
expression[ApproximatePercentile]("approx_percentile"),
expression[StddevSamp]("std"),
expression[StddevSamp]("stddev"),
expressionWithAlias[ApproximatePercentile]("percentile_approx"),
expressionWithAlias[ApproximatePercentile]("approx_percentile"),
expressionWithAlias[StddevSamp]("std"),
expressionWithAlias[StddevSamp]("stddev"),
expressionWithAlias[StddevSamp]("stddev_samp"),
expression[StddevPop]("stddev_pop"),
expression[StddevSamp]("stddev_samp"),
expression[Sum]("sum"),
expression[VarianceSamp]("variance"),
expressionWithAlias[VarianceSamp]("variance"),
expressionWithAlias[VarianceSamp]("var_samp"),
expression[VariancePop]("var_pop"),
expression[VarianceSamp]("var_samp"),
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),
Expand Down Expand Up @@ -635,7 +635,9 @@ object FunctionRegistry {
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
val constructors = tag.runtimeClass.getConstructors
.filter(_.getParameterTypes.head == classOf[String])
assert(constructors.length == 1)
assert(constructors.length >= 1,
s"there is no constructor for ${tag.runtimeClass} " +
"which takes String as first argument")
val builder = (expressions: Seq[Expression]) => {
val params = classOf[String] +: Seq.fill(expressions.size)(classOf[Expression])
val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,13 @@ object TypeCoercion {

case Abs(e @ StringType()) => Abs(Cast(e, DoubleType))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case Average(funcName, e @ StringType()) => Average(funcName, Cast(e, DoubleType))
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
case StddevSamp(funcName, e @ StringType()) => StddevSamp(funcName, Cast(e, DoubleType))
case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType))
case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType))
case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
case VarianceSamp(funcName, e @ StringType()) => VarianceSamp(funcName, Cast(e, DoubleType))
case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
}
Expand Down Expand Up @@ -613,15 +613,15 @@ object TypeCoercion {
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))

case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest.
case Average(e @ IntegralType()) if e.dataType != LongType =>
Average(Cast(e, LongType))
case Average(e @ FractionalType()) if e.dataType != DoubleType =>
Average(Cast(e, DoubleType))
case s @ Average(_, DecimalType()) => s // Decimal is already the biggest.
case Average(funcName, e @ IntegralType()) if e.dataType != LongType =>
Average(funcName, Cast(e, LongType))
case Average(funcName, e @ FractionalType()) if e.dataType != DoubleType =>
Average(funcName, Cast(e, DoubleType))

// Hive lets you do aggregation of timestamps... for some reason
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
case Average(funcName, e @ TimestampType()) => Average(funcName, Cast(e, DoubleType))

// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ package object dsl {
def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
HyperLogLogPlusPlus(e, rsd).toAggregateExpression()
def avg(e: Expression): Expression = Average(e).toAggregateExpression()
def first(e: Expression): Expression = new First(e).toAggregateExpression()
def last(e: Expression): Expression = new Last(e).toAggregateExpression()
def first(e: Expression): Expression = First(e).toAggregateExpression()
def last(e: Expression): Expression = Last(e).toAggregateExpression()
def min(e: Expression): Expression = Min(e).toAggregateExpression()
def minDistinct(e: Expression): Expression = Min(e).toAggregateExpression(isDistinct = true)
def max(e: Expression): Expression = Max(e).toAggregateExpression()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,25 @@ import org.apache.spark.sql.types._
""",
since = "2.1.0")
case class ApproximatePercentile(
funcName: String,
child: Expression,
percentageExpression: Expression,
accuracyExpression: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes {

def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = {
this(child, percentageExpression, accuracyExpression, 0, 0)
def this(
funcName: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

child: Expression,
percentageExpression: Expression,
accuracyExpression: Expression) = {
this(funcName, child, percentageExpression, accuracyExpression, 0, 0)
}

def this(child: Expression, percentageExpression: Expression) = {
this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
def this(funcName: String, child: Expression, percentageExpression: Expression) = {
this(funcName, child, percentageExpression,
Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
}

// Mark as lazy so that accuracyExpression is not evaluated during tree transformation.
Expand Down Expand Up @@ -185,7 +191,7 @@ case class ApproximatePercentile(
if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
}

override def prettyName: String = "percentile_approx"
override def nodeName: String = funcName

override def serialize(obj: PercentileDigest): Array[Byte] = {
ApproximatePercentile.serializer.serialize(obj)
Expand All @@ -194,6 +200,10 @@ case class ApproximatePercentile(
override def deserialize(bytes: Array[Byte]): PercentileDigest = {
ApproximatePercentile.serializer.deserialize(bytes)
}

override def flatArguments: Iterator[Any] =
Iterator(child, percentageExpression, accuracyExpression,
mutableAggBufferOffset, inputAggBufferOffset)
}

object ApproximatePercentile {
Expand Down Expand Up @@ -321,4 +331,22 @@ object ApproximatePercentile {
}

val serializer: PercentileDigestSerializer = new PercentileDigestSerializer

def apply(
child: Expression,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: 4 space indentation

percentageExpression: Expression,
accuracyExpression: Expression,
mutableAggBufferOffset: Int, inputAggBufferOffset: Int): ApproximatePercentile =
new ApproximatePercentile("percentile_approx", child, percentageExpression,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry to bring this up so late, but after more thoughts, I feel it's too invasive to add a funcName parameter to expression constructor: we need to add a lot of apply methods and constructors to provide the default function name.

This problem wasn't obvious when we were changing BoolAnd, but it's a bit obvious now for ApproximatePercentile which already have a lot of constructors.

I'd like to propose another idea: using mutable states via TreeNodeTag. The idea is pretty simple: an expression can have an optional alias, which is stored in the expression itself via TreeNodeTag. We can update ApproximatePercentile.prettyName to getTagValue(THE_TAG).getOrElse("percentile_approx").

In expressionWithAlias, it can just call expression, and set the TreeNodeTag with the alias at the end.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a great idea. I will implement and update this PR. Thanks @cloud-fan

accuracyExpression, mutableAggBufferOffset, inputAggBufferOffset)

def apply(
child: Expression,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

percentageExpression: Expression,
accuracyExpression: Expression): ApproximatePercentile = {
new ApproximatePercentile("percentile_approx", child, percentageExpression, accuracyExpression)
}

def apply(child: Expression, percentageExpression: Expression): ApproximatePercentile =
new ApproximatePercentile("percentile_approx", child, percentageExpression)
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ import org.apache.spark.sql.types._
-3 days -11 hours -59 minutes -59 seconds
""",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
case class Average(
funcName: String, child: Expression)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

extends DeclarativeAggregate with ImplicitCastInputTypes {

override def prettyName: String = "avg"
override def nodeName: String = funcName

override def children: Seq[Expression] = child :: Nil

Expand Down Expand Up @@ -93,4 +95,10 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
/* count = */ If(child.isNull, count, count + 1L)
)

override def flatArguments: Iterator[Any] = Iterator(child)
}

object Average{
def apply(child: Expression): Average = Average("avg", child)
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
""",
since = "1.6.0")
// scalastyle:on line.size.limit
case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
case class StddevSamp(funcName: String, child: Expression) extends CentralMomentAgg(child) {

override protected def momentOrder = 2

Expand All @@ -174,7 +174,9 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0))))
}

override def prettyName: String = "stddev_samp"
override def nodeName: String = funcName

override def flatArguments: Iterator[Any] = Iterator(child)
}

// Compute the population variance of a column
Expand Down Expand Up @@ -206,7 +208,7 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
1.0
""",
since = "1.6.0")
case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
case class VarianceSamp(funcName: String, child: Expression) extends CentralMomentAgg(child) {

override protected def momentOrder = 2

Expand All @@ -215,7 +217,9 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
If(n === 1.0, Double.NaN, m2 / (n - 1.0)))
}

override def prettyName: String = "var_samp"
override def nodeName: String = funcName

override def flatArguments: Iterator[Any] = Iterator(child)
}

@ExpressionDescription(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ import org.apache.spark.sql.types._
5
""",
since = "2.0.0")
case class First(child: Expression, ignoreNullsExpr: Expression)
case class First(funcName: String, child: Expression, ignoreNullsExpr: Expression)
extends DeclarativeAggregate with ExpectsInputTypes {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))
def this(funcName: String, child: Expression) =
this(funcName, child, Literal.create(false, BooleanType))

override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil

Expand Down Expand Up @@ -113,5 +114,18 @@ case class First(child: Expression, ignoreNullsExpr: Expression)

override lazy val evaluateExpression: AttributeReference = first

override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}"
override def nodeName: String = funcName

override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"

override def flatArguments: Iterator[Any] = Iterator(child, ignoreNullsExpr)
}

object First {

def apply(child: Expression, ignoreNullsExpr: Expression): First =
new First("first", child, ignoreNullsExpr)

def apply(child: Expression): First =
new First("first", child)
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ import org.apache.spark.sql.types._
5
""",
since = "2.0.0")
case class Last(child: Expression, ignoreNullsExpr: Expression)
case class Last(funcName: String, child: Expression, ignoreNullsExpr: Expression)
extends DeclarativeAggregate with ExpectsInputTypes {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))
def this(funcName: String, child: Expression) =
this(funcName, child, Literal.create(false, BooleanType))

override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil

Expand Down Expand Up @@ -111,5 +112,18 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)

override lazy val evaluateExpression: AttributeReference = last

override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
override def nodeName: String = funcName

override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"

override def flatArguments: Iterator[Any] = Iterator(child, ignoreNullsExpr)
}

object Last {

def apply(child: Expression, ignoreNullsExpr: Expression): Last =
new Last("last", child, ignoreNullsExpr)

def apply(child: Expression): Last =
new Last("last", child)
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
case _ => TypeCheckResult.TypeCheckSuccess
}
}

override def flatArguments: Iterator[Any] = Iterator(arg)
}

@ExpressionDescription(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1464,9 +1464,9 @@ object DecimalAggregates extends Rule[LogicalPlan] {
MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
prec + 10, scale)

case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
case Average(f, e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr =
we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
we.copy(windowFunction = ae.copy(aggregateFunction = Average(f, UnscaledValue(e))))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone))
Expand All @@ -1477,8 +1477,8 @@ object DecimalAggregates extends Rule[LogicalPlan] {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)

case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
case Average(f, e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction = Average(f, UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4), Option(SQLConf.get.sessionLocalTimeZone))
Expand Down Expand Up @@ -1539,7 +1539,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
if (keyExprIds.contains(attr.exprId)) {
attr
} else {
Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
Alias(First(attr).toAggregateExpression(), attr.name)(attr.exprId)
}
}
// SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
Expand Down
Loading