@@ -144,14 +144,23 @@ class CodegenContext {

private val curId = new java.util.concurrent.atomic.AtomicInteger()

/**
* A prefix used to generate fresh names.
*/
var freshNamePrefix = ""

/**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
* (Since we aren't in a macro context we do not seem to have access to the built-in `freshName`
* function.)
*/
def freshName(prefix: String): String = {
s"$prefix${curId.getAndIncrement}"
def freshName(name: String): String = {
if (freshNamePrefix == "") {
s"$name${curId.getAndIncrement}"
} else {
s"${freshNamePrefix}_$name${curId.getAndIncrement}"
}
}
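For illustration, a minimal standalone sketch of the new naming scheme (a hypothetical repro, not the actual Spark class). Uniqueness still comes from the atomic counter; the prefix only makes generated names traceable to an operator, so two operators that share a nodeName are still safe:

import java.util.concurrent.atomic.AtomicInteger

class FreshNames {
  private val curId = new AtomicInteger()
  var freshNamePrefix = ""

  // Mirrors the patched freshName above: prepend the prefix when one is set.
  def freshName(name: String): String =
    if (freshNamePrefix == "") s"$name${curId.getAndIncrement}"
    else s"${freshNamePrefix}_$name${curId.getAndIncrement}"
}

val ctx = new FreshNames
ctx.freshName("row")            // "row0"
ctx.freshNamePrefix = "Filter"
ctx.freshName("row")            // "Filter_row1"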

@@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
// Can't call setNullAt on DecimalType, because we need to keep the offset
s"""
if (this.isNull_$i) {
${ctx.setColumn("mutableRow", e.dataType, i, null)};
${ctx.setColumn("mutableRow", e.dataType, i, "null")};
} else {
${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
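The switch from null to "null" matters because setColumn splices its value argument into generated Java source text. A minimal sketch of that splicing, assuming a simplified setColumn (the real helper also dispatches on the data type and row accessor):

// Hypothetical, simplified stand-in for ctx.setColumn, for illustration only.
def setColumn(row: String, ordinal: Int, value: String): String =
  s"$row.update($ordinal, $value)"

setColumn("mutableRow", 0, "null")
// => "mutableRow.update(0, null)", i.e. the Java null literal in the output

Passing the Scala reference null happens to render the same text through the string interpolator, but it is fragile: any helper that calls a method on its String argument would throw a NullPointerException.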
@@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.Utils

/**
* An interface for those physical operators that support codegen.
@@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan {
private var parent: CodegenSupport = null

/**
* Returns an input RDD of InternalRow and Java source code to process them.
* Returns the RDD of InternalRow which generates the input rows.
*/
def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = {
def upstream(): RDD[InternalRow]

/**
* Returns Java source code to process the rows from upstream.
*/
def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
ctx.freshNamePrefix = nodeName
Reviewer comment: Do we have a notion of node id? This is not going to help when we have many joins in one pipeline.

Author reply: Good question. Right now we don't have a unique id for SparkPlan.

doProduce(ctx)
}

@@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
protected def doProduce(ctx: CodegenContext): String

/**
* Consume the columns generated from the current SparkPlan, call its parent or create an iterator.
* Consume the columns generated from the current SparkPlan, call its parent.
*/
protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
assert(columns.length == output.length)
parent.doConsume(ctx, this, columns)
def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
if (input != null) {
assert(input.length == output.length)
}
parent.consumeChild(ctx, this, input, row)
}

/**
* Consume the columns generated from its child, then call doConsume() or emit the rows.
*/
def consumeChild(
ctx: CodegenContext,
child: SparkPlan,
input: Seq[ExprCode],
row: String = null): String = {
ctx.freshNamePrefix = nodeName
if (row != null) {
ctx.currentVars = null
ctx.INPUT_ROW = row
val evals = child.output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
}
s"""
| ${evals.map(_.code).mkString("\n")}
| ${doConsume(ctx, evals)}
""".stripMargin
} else {
doConsume(ctx, input)
}
}

/**
* Generate the Java source code to process the rows from child SparkPlan.
@@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
throw new UnsupportedOperationException
}
}
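To make the produce/consume contract concrete, here is a hedged sketch of a filter-like operator written against this trait. SimpleFilter is hypothetical; the trait members (upstream, produce, doProduce, doConsume, consume) and the expression codegen calls (BindReferences.bindReference, gen, ctx.currentVars) are the ones used elsewhere in this diff:

case class SimpleFilter(condition: Expression, child: SparkPlan)
  extends UnaryNode with CodegenSupport {

  override def output: Seq[Attribute] = child.output

  // Rows enter the stage through whatever feeds the child.
  override def upstream(): RDD[InternalRow] =
    child.asInstanceOf[CodegenSupport].upstream()

  // produce() walks down the tree; the child calls back into our doConsume().
  override protected def doProduce(ctx: CodegenContext): String =
    child.asInstanceOf[CodegenSupport].produce(ctx, this)

  // Evaluate the predicate against the child's column variables and only
  // forward passing rows to our parent via consume().
  override protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
    ctx.currentVars = input
    val eval = BindReferences.bindReference(condition, child.output).gen(ctx)
    s"""
       | ${eval.code}
       | if (!${eval.isNull} && ${eval.value}) {
       |   ${consume(ctx, input)}
       | }
     """.stripMargin
  }
}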


@@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan {
case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {

override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def doPrepare(): Unit = {
child.prepare()
}

override def supportCodegen: Boolean = true
override def doExecute(): RDD[InternalRow] = {
child.execute()
}

override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
override def supportCodegen: Boolean = false

override def upstream(): RDD[InternalRow] = {
child.execute()
}

override def doProduce(ctx: CodegenContext): String = {
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
val code = s"""
s"""
| while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n")}
| ${consume(ctx, columns)}
| }
""".stripMargin
(child.execute(), code)
}

def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
throw new UnsupportedOperationException
}

override def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException
}

override def simpleString: String = "INPUT"
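To show where InputAdapter sits, here is an illustrative hand-wiring of one stage (the CollapseCodegenStages rule at the end of this diff builds this automatically; nonCodegenChild is a placeholder):

// Illustrative only: InputAdapter is the in-stage leaf that reads rows from
// a non-codegen child via upstream() and feeds them to the generated loop.
val nonCodegenChild: SparkPlan = ???   // e.g. an exchange or a data source scan
val stage = WholeStageCodegen(
  InputAdapter(nonCodegenChild),
  children = Seq(nonCodegenChild))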
@@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
*
 * -> execute()
 *     |
 *  doExecute() --------->  upstream() -------> upstream() ------> execute()
 *     |
 *      ----------------->  produce()
 *                             |
 *                          doProduce() -------> produce()
 *                                                  |
 *                                               doProduce()
 *                                                  |
 *                                               consume()
 *                        consumeChild() <-----------|
 *                             |
 *                          doConsume()
 *                             |
 *  consumeChild() <-----  consume()
*
* SparkPlan A should override doProduce() and doConsume().
*
@@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
extends SparkPlan with CodegenSupport {

override def supportCodegen: Boolean = false

override def output: Seq[Attribute] = plan.output
override def outputPartitioning: Partitioning = plan.outputPartitioning
override def outputOrdering: Seq[SortOrder] = plan.outputOrdering

override def doPrepare(): Unit = {
plan.prepare()
}

override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
val (rdd, code) = plan.produce(ctx, this)
val code = plan.produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
Reviewer comment: Can you comment what references mean? references is a very generic name.

return new GeneratedIterator(references);
}

class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

private Object[] references;
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}

public GeneratedIterator(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
}
}

protected void processNext() {
protected void processNext() throws java.io.IOException {
$code
}
}
}
"""
"""

// try to compile, helpful for debugging
// println(s"${CodeFormatter.format(source)}")
CodeGenerator.compile(source)

rdd.mapPartitions { iter =>
plan.upstream().mapPartitions { iter =>

val clazz = CodeGenerator.compile(source)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.setInput(iter)
@@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
}
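The generated class plugs into BufferedRowIterator: the caller installs the input iterator with setInput, and each processNext() call consumes input rows until it assigns currentRow (or leaves it null once the input is exhausted). A toy Scala model of that contract as used here (the real class is a Java class in Spark and may differ in detail):

abstract class ToyBufferedRowIterator {
  protected var currentRow: InternalRow = _
  protected var input: Iterator[InternalRow] = _

  def setInput(iter: Iterator[InternalRow]): Unit = { input = iter }

  // The generated processNext() body must either assign currentRow and
  // return, or leave it null when there is nothing left to produce.
  protected def processNext(): Unit

  def hasNext: Boolean = {
    if (currentRow == null) processNext()
    currentRow != null
  }

  def next(): InternalRow = {
    val r = currentRow
    currentRow = null
    r
  }
}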

override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
override def upstream(): RDD[InternalRow] = {
throw new UnsupportedOperationException
}

override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
if (input.nonEmpty) {
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
// generate the code to create a UnsafeRow
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
| ${code.code.trim}
| currentRow = ${code.value};
| return;
""".stripMargin
} else {
// There is no columns
override def doProduce(ctx: CodegenContext): String = {
throw new UnsupportedOperationException
}

override def consumeChild(
ctx: CodegenContext,
child: SparkPlan,
input: Seq[ExprCode],
row: String = null): String = {

if (row != null) {
// There is an UnsafeRow already
s"""
| currentRow = unsafeRow;
| currentRow = $row;
| return;
""".stripMargin
} else {
assert(input != null)
if (input.nonEmpty) {
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
// generate the code to create an UnsafeRow
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
| ${code.code.trim}
| currentRow = ${code.value};
| return;
""".stripMargin
} else {
// There are no columns
s"""
| currentRow = unsafeRow;
| return;
""".stripMargin
}
}
}

@@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
builder.append(simpleString)
builder.append("\n")

plan.generateTreeString(depth + 1, lastChildren :+ children.isEmpty :+ true, builder)
plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
@@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
case plan: CodegenSupport if supportCodegen(plan) &&
// Whole stage codegen is only useful when there are at least two levels of operators that
// support it (save at least one projection/iterator).
plan.children.exists(supportCodegen) =>
(Utils.isTesting || plan.children.exists(supportCodegen)) =>

var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
case p if !supportCodegen(p) =>
inputs += p
InputAdapter(p)
val input = apply(p) // collapse them recursively
inputs += input
InputAdapter(input)
}.asInstanceOf[CodegenSupport]
WholeStageCodegen(combined, inputs)
}
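The recursive apply(p) is the key change in this rule: a non-codegen operator can itself contain codegen-able subtrees, and recursing gives each of them its own WholeStageCodegen stage instead of freezing the whole subtree behind a single InputAdapter. A toy, self-contained model of the collapse (illustrative names, not Spark code):

sealed trait Plan
case class Op(name: String, codegen: Boolean, children: List[Plan] = Nil) extends Plan
case class Adapter(child: Plan) extends Plan
case class Stage(plan: Plan, inputs: List[Plan]) extends Plan

def hasCodegen(p: Plan): Boolean = p match {
  case Op(_, c, _) => c
  case _ => false
}

def insertStages(p: Plan): Plan = p match {
  case op @ Op(_, true, kids) if kids.exists(hasCodegen) =>
    var inputs = List.empty[Plan]
    def wrap(q: Plan): Plan = q match {
      case Op(n, true, cs) => Op(n, true, cs.map(wrap))
      case other =>
        val collapsed = insertStages(other)   // the recursive call added above
        inputs = inputs :+ collapsed
        Adapter(collapsed)
    }
    Stage(wrap(op), inputs)
  case Op(n, c, cs) => Op(n, c, cs.map(insertStages))
  case other => other
}

// Exchange does not support codegen; everything else here does.
val scan  = Op("Scan", false)
val inner = Op("Sort", true, List(Op("Filter", true, List(scan))))
val plan  = Op("Project", true,
  List(Op("Filter", true,
    List(Op("Exchange", false, List(inner))))))

insertStages(plan)
// => Stage(Project(Filter(Adapter(Exchange(Stage(Sort(Filter(Adapter(Scan))), ...))))), ...)
// Without the recursion, the Sort/Filter subtree under Exchange would stay uncompiled.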