Skip to content

Commit 213ada8

Browse files
committed
First draft of partially aggregated and code-generated count distinct / max
1 parent 73ab7f1 commit 213ada8

File tree

9 files changed

+428
-13
lines changed

9 files changed

+428
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
2727
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
2828
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
2929

30-
protected val exprArray = expressions.toArray
30+
// null check is required for when Kryo invokes the no-arg constructor.
31+
protected val exprArray = if (expressions != null) expressions.toArray else null
3132

3233
def apply(input: Row): Row = {
3334
val outputArray = new Array[Any](exprArray.length)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 96 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
2222
import org.apache.spark.sql.catalyst.types._
2323
import org.apache.spark.sql.catalyst.trees
2424
import org.apache.spark.sql.catalyst.errors.TreeNodeException
25+
import org.apache.spark.util.collection.OpenHashSet
2526

2627
abstract class AggregateExpression extends Expression {
2728
self: Product =>
@@ -161,13 +162,96 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
161162
override def newInstance() = new CountFunction(child, this)
162163
}
163164

164-
case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
165+
case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
166+
def this() = this(null)
167+
165168
override def children = expressions
166169
override def references = expressions.flatMap(_.references).toSet
167170
override def nullable = false
168171
override def dataType = LongType
169172
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
170173
override def newInstance() = new CountDistinctFunction(expressions, this)
174+
175+
override def asPartial = {
176+
val partialSet = Alias(CollectHashSet(expressions), "partialSets")()
177+
SplitEvaluation(
178+
CombineSetsAndCount(partialSet.toAttribute),
179+
partialSet :: Nil)
180+
}
181+
}
182+
183+
case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
184+
def this() = this(null)
185+
186+
override def children = expressions
187+
override def references = expressions.flatMap(_.references).toSet
188+
override def nullable = false
189+
override def dataType = ArrayType(expressions.head.dataType)
190+
override def toString = s"AddToHashSet(${expressions.mkString(",")})"
191+
override def newInstance() = new CollectHashSetFunction(expressions, this)
192+
}
193+
194+
case class CollectHashSetFunction(
195+
@transient expr: Seq[Expression],
196+
@transient base: AggregateExpression)
197+
extends MergableAggregateFunction {
198+
199+
def this() = this(null, null) // Required for serialization.
200+
201+
val seen = new OpenHashSet[Any]()
202+
203+
@transient
204+
val distinctValue = new InterpretedProjection(expr)
205+
206+
override def merge(other: MergableAggregateFunction): MergableAggregateFunction = {
207+
val otherSetIterator = other.asInstanceOf[CountDistinctFunction].seen.iterator
208+
while(otherSetIterator.hasNext) {
209+
seen.add(otherSetIterator.next())
210+
}
211+
this
212+
}
213+
214+
override def update(input: Row): Unit = {
215+
val evaluatedExpr = distinctValue(input)
216+
if (!evaluatedExpr.anyNull) {
217+
seen.add(evaluatedExpr)
218+
}
219+
}
220+
221+
override def eval(input: Row): Any = {
222+
seen
223+
}
224+
}
225+
226+
case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
227+
def this() = this(null)
228+
229+
override def children = inputSet :: Nil
230+
override def references = inputSet.references
231+
override def nullable = false
232+
override def dataType = LongType
233+
override def toString = s"CombineAndCount($inputSet)"
234+
override def newInstance() = new CombineSetsAndCountFunction(inputSet, this)
235+
}
236+
237+
case class CombineSetsAndCountFunction(
238+
@transient inputSet: Expression,
239+
@transient base: AggregateExpression)
240+
extends AggregateFunction {
241+
242+
def this() = this(null, null) // Required for serialization.
243+
244+
val seen = new OpenHashSet[Any]()
245+
246+
override def update(input: Row): Unit = {
247+
val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
248+
val inputIterator = inputSetEval.iterator
249+
while (inputIterator.hasNext) {
250+
seen.add(inputIterator.next)
251+
}
252+
}
253+
254+
override def eval(input: Row): Any = seen.size.toLong
171255
}
172256

173257
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
@@ -379,17 +463,22 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
379463
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
380464
}
381465

382-
case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression)
383-
extends AggregateFunction {
466+
case class CountDistinctFunction(
467+
@transient expr: Seq[Expression],
468+
@transient base: AggregateExpression)
469+
extends MergableAggregateFunction {
384470

385471
def this() = this(null, null) // Required for serialization.
386472

387-
val seen = new scala.collection.mutable.HashSet[Any]()
473+
val seen = new OpenHashSet[Any]()
474+
475+
@transient
476+
val distinctValue = new InterpretedProjection(expr)
388477

389478
override def update(input: Row): Unit = {
390-
val evaluatedExpr = expr.map(_.eval(input))
391-
if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
392-
seen += evaluatedExpr
479+
val evaluatedExpr = distinctValue(input)
480+
if (!evaluatedExpr.anyNull) {
481+
seen.add(evaluatedExpr)
393482
}
394483
}
395484

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,17 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
8585

8686
override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
8787
}
88+
89+
case class MaxOf(left: Expression, right: Expression) extends Expression {
90+
type EvaluatedType = Any
91+
92+
override def nullable = left.nullable && right.nullable
93+
94+
override def children = left :: right :: Nil
95+
96+
override def references = (left.flatMap(_.references) ++ right.flatMap(_.references)).toSet
97+
98+
override def dataType = left.dataType
99+
100+
override def eval(input: Row): Any = ???
101+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ import org.apache.spark.sql.catalyst.expressions
2626
import org.apache.spark.sql.catalyst.expressions._
2727
import org.apache.spark.sql.catalyst.types._
2828

29+
class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
30+
class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
31+
2932
/**
3033
* A base class for generators of byte code to perform expression evaluation. Includes a set of
3134
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
@@ -71,7 +74,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
7174
* From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
7275
* fundamental difference is that a ConcurrentMap persists all elements that are added to it until
7376
* they are explicitly removed. A Cache on the other hand is generally configured to evict entries
74-
* automatically, in order to constrain its memory footprint
77+
* automatically, in order to constrain its memory footprint. Note that this cache does not use
78+
* weak keys/values and thus does not respond to memory pressure.
7579
*/
7680
protected val cache = CacheBuilder.newBuilder()
7781
.maximumSize(1000)
@@ -398,6 +402,75 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
398402
$primitiveTerm = ${falseEval.primitiveTerm}
399403
}
400404
""".children
405+
406+
case NewSet(elementType) =>
407+
q"""
408+
val $nullTerm = false
409+
val $primitiveTerm = new ${hashSetForType(elementType)}()
410+
""".children
411+
412+
case AddItemToSet(item, set) =>
413+
val itemEval = expressionEvaluator(item)
414+
val setEval = expressionEvaluator(set)
415+
416+
val ArrayType(elementType, _) = set.dataType
417+
418+
itemEval.code ++ setEval.code ++
419+
q"""
420+
if (!${itemEval.nullTerm}) {
421+
${setEval.primitiveTerm}
422+
.asInstanceOf[${hashSetForType(elementType)}]
423+
.add(${itemEval.primitiveTerm})
424+
}
425+
426+
val $nullTerm = false
427+
val $primitiveTerm = ${setEval.primitiveTerm}
428+
""".children
429+
430+
case CombineSets(left, right) =>
431+
val leftEval = expressionEvaluator(left)
432+
val rightEval = expressionEvaluator(right)
433+
434+
val ArrayType(elementType, _) = left.dataType
435+
436+
leftEval.code ++ rightEval.code ++
437+
q"""
438+
val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
439+
val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
440+
val iterator = rightSet.iterator
441+
while (iterator.hasNext) {
442+
leftSet.add(iterator.next())
443+
}
444+
445+
val $nullTerm = false
446+
val $primitiveTerm = leftSet
447+
""".children
448+
449+
case MaxOf(e1, e2) =>
450+
val eval1 = expressionEvaluator(e1)
451+
val eval2 = expressionEvaluator(e2)
452+
453+
eval1.code ++ eval2.code ++
454+
q"""
455+
var $nullTerm = false
456+
var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}
457+
458+
if (${eval1.nullTerm}) {
459+
$nullTerm = ${eval2.nullTerm}
460+
$primitiveTerm = ${eval2.primitiveTerm}
461+
} else if (${eval2.nullTerm}) {
462+
$nullTerm = ${eval1.nullTerm}
463+
$primitiveTerm = ${eval1.primitiveTerm}
464+
} else {
465+
$nullTerm = false
466+
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
467+
$primitiveTerm = ${eval1.primitiveTerm}
468+
} else {
469+
$primitiveTerm = ${eval2.primitiveTerm}
470+
}
471+
}
472+
""".children
473+
401474
}
402475

403476
// If there was no match in the partial function above, we fall back on calling the interpreted
@@ -437,6 +510,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
437510
protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}")
438511
protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}")
439512

513+
protected def hashSetForType(dt: DataType) = dt match {
514+
case IntegerType => typeOf[IntegerHashSet]
515+
case LongType => typeOf[LongHashSet]
516+
}
517+
440518
protected def primitiveForType(dt: DataType) = dt match {
441519
case IntegerType => "Int"
442520
case LongType => "Long"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
6969
..${evaluatedExpression.code}
7070
if(${evaluatedExpression.nullTerm})
7171
setNullAt($iLit)
72-
else
72+
else {
73+
nullBits($iLit) = false
7374
$elementName = ${evaluatedExpression.primitiveTerm}
75+
}
7476
}
7577
""".children : Seq[Tree]
7678
}
@@ -106,9 +108,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
106108
if(value == null) {
107109
setNullAt(i)
108110
} else {
111+
nullBits(i) = false
109112
$elementName = value.asInstanceOf[${termForType(e.dataType)}]
110-
return
111113
}
114+
return
112115
}"""
113116
}
114117
q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
@@ -137,7 +140,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
137140
val elementName = newTermName(s"c$i")
138141
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
139142
// TODO: Optional null checks?
140-
q"if(i == $i) { $elementName = value; return }" :: Nil
143+
q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
141144
case _ => Nil
142145
}
143146

0 commit comments

Comments
 (0)