Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 5c00f3f

Browse files
marmbrus authored and
yhuai committed
First draft of codegen
1 parent 6bbc6ba commit 5c00f3f

File tree

5 files changed

+66
-29
lines changed

5 files changed

+66
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
3131
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
3232
extends NamedExpression with trees.LeafNode[Expression] {
3333

34-
override def toString: String = s"input[$ordinal]"
34+
override def toString: String = s"input[$ordinal, $dataType]"
3535

3636
override def eval(input: InternalRow): Any = input(ordinal)
3737

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ abstract class Expression extends TreeNode[Expression] {
7272
val primitive = ctx.freshName("primitive")
7373
val ve = GeneratedExpressionCode("", isNull, primitive)
7474
ve.code = genCode(ctx, ve)
75-
ve
75+
ve.copy(s"/* $this */\n" + ve.code)
7676
}
7777

7878
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,13 @@ private[sql] case object Final extends AggregateMode
3434

3535
private[sql] case object Complete extends AggregateMode
3636

37+
/**
 * A placeholder [[Expression]] that marks an output position for which nothing should be
 * computed.
 *
 * In this change, `GenerateMutableProjection` maps a `NoOp` entry to an empty code string,
 * so no assignment is generated for that ordinal, and `Aggregate2Sort` uses `NoOp` to pad
 * the leading grouping-expression slots of the aggregation buffer in the post-shuffle case.
 *
 * It is declared nullable with data type `NullType` and has no children.
 */
case object NoOp extends Expression {
38+
override def nullable: Boolean = true
39+
// `???` throws scala.NotImplementedError: this expression has no interpreted evaluation.
// NOTE(review): presumably callers must never evaluate NoOp directly — confirm that the
// interpreted projection path cannot receive it.
override def eval(input: expressions.InternalRow): Any = ???
40+
override def dataType: DataType = NullType
41+
override def children: Seq[Expression] = Nil
42+
}
43+
3744
/**
3845
* A container of an Aggregate Function, Aggregate Mode, and a field (`isDistinct`) indicating
3946
* if DISTINCT keyword is specified for this function.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions.codegen
1919

2020
import org.apache.spark.sql.catalyst.expressions._
21+
import org.apache.spark.sql.catalyst.expressions.aggregate2.NoOp
2122

2223
// MutableProjection is not accessible in Java
2324
abstract class BaseMutableProjection extends MutableProjection
@@ -36,15 +37,18 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
3637

3738
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
3839
val ctx = newCodeGenContext()
39-
val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
40-
val evaluationCode = e.gen(ctx)
41-
evaluationCode.code +
42-
s"""
43-
if(${evaluationCode.isNull})
44-
mutableRow.setNullAt($i);
45-
else
46-
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
47-
"""
40+
val projectionCode = expressions.zipWithIndex.map {
41+
case (NoOp, _) => ""
42+
case (e, i) =>
43+
val evaluationCode = e.gen(ctx)
44+
evaluationCode.code +
45+
s"""
46+
/** output[$i] = $e */
47+
if(${evaluationCode.isNull})
48+
mutableRow.setNullAt($i);
49+
else
50+
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
51+
"""
4852
}.mkString("\n")
4953
val code = s"""
5054
public Object generate($exprType[] expr) {
@@ -80,7 +84,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
8084
}
8185
"""
8286

83-
logDebug(s"code for ${expressions.mkString(",")}:\n$code")
87+
logWarning(s"code for ${expressions.mkString(",")}:\n$code")
8488

8589
val c = compile(code)
8690
() => {

sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate2Sort.scala

Lines changed: 43 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.errors._
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate2._
2424
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, AllTuples, UnspecifiedDistribution, Distribution}
25+
import org.apache.spark.sql.types.NullType
2526

2627
case class Aggregate2Sort(
2728
preShuffle: Boolean,
@@ -66,7 +67,7 @@ case class Aggregate2Sort(
6667
while (i < aggregateExpressions.length) {
6768
val func = aggregateExpressions(i).aggregateFunction.withBufferOffset(bufferOffset)
6869
functions(i) = aggregateExpressions(i).mode match {
69-
case Partial | Complete => BindReferences.bindReference(func, child.output)
70+
case Partial | Complete => func
7071
case PartialMerge | Final => func
7172
}
7273
bufferOffset = aggregateExpressions(i).mode match {
@@ -118,6 +119,43 @@ case class Aggregate2Sort(
118119
new InterpretedMutableProjection(
119120
resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)
120121

122+
val offsetAttributes = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(AttributeReference("offset", NullType)())
123+
val offsetExpressions = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(NoOp)
124+
125+
val initialProjection = {
126+
val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
127+
case ae: AlgebraicAggregate => ae.initialValues
128+
}
129+
println(initExpressions.mkString(","))
130+
newMutableProjection(initExpressions, Nil)().target(buffer)
131+
}
132+
133+
lazy val updateProjection = {
134+
val bufferSchema = aggregateFunctions.flatMap {
135+
case ae: AlgebraicAggregate => ae.bufferSchema
136+
}
137+
val updateExpressions = aggregateFunctions.flatMap {
138+
case ae: AlgebraicAggregate => ae.updateExpressions
139+
}
140+
141+
println(updateExpressions.mkString(","))
142+
newMutableProjection(updateExpressions, bufferSchema ++ child.output)().target(buffer)
143+
}
144+
145+
val mergeProjection = {
146+
val bufferSchemata =
147+
offsetAttributes ++ aggregateFunctions.flatMap {
148+
case ae: AlgebraicAggregate => ae.bufferSchema
149+
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
150+
case ae: AlgebraicAggregate => ae.rightBufferSchema
151+
}
152+
val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
153+
case ae: AlgebraicAggregate => ae.mergeExpressions
154+
}
155+
156+
newMutableProjection(mergeExpressions, bufferSchemata)()
157+
}
158+
121159
// Initialize this iterator.
122160
initialize()
123161

@@ -136,28 +174,16 @@ case class Aggregate2Sort(
136174
}
137175

138176
private def initializeBuffer(): Unit = {
139-
var i = 0
140-
while (i < aggregateFunctions.length) {
141-
aggregateFunctions(i).initialize(buffer)
142-
i += 1
143-
}
177+
initialProjection(EmptyRow)
178+
println("initilized: " + buffer)
144179
}
145180

146181
private def processRow(row: InternalRow): Unit = {
147182
// The new row is still in the current group.
148183
if (preShuffle) {
149-
var i = 0
150-
while (i < aggregateFunctions.length) {
151-
aggregateFunctions(i).update(buffer, row)
152-
i += 1
153-
}
184+
updateProjection(joinedRow(buffer, row))
154185
} else {
155-
var i = 0
156-
println("post shuffle: " + buffer + " " + row)
157-
while (i < aggregateFunctions.length) {
158-
aggregateFunctions(i).merge(buffer, row)
159-
i += 1
160-
}
186+
mergeProjection.target(buffer)(joinedRow(buffer, row))
161187
}
162188
}
163189

0 commit comments

Comments
 (0)