-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-15214][SQL] Code-generation for Generate #13065
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5254559
9df56a9
5d068b5
b2d663b
e04d66f
43a04bf
7b4772d
b721b60
09513e7
f7c2307
f5bd9cf
dba4240
49f9e7f
b3531cb
5cfba19
f86da0f
60da24e
87688b1
1d2d595
c9b3eda
2732b06
36cd826
5b3d9bd
3a40952
c41e308
116339a
2c6c7f2
d20114b
ad36de5
757b470
8c14194
459714c
7b7fa6e
ebd9d8c
f81eed7
29c606a
af9a516
3146cc5
ffd5ef8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,10 +17,12 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.expressions | ||
|
|
||
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.sql.Row | ||
| import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} | ||
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} | ||
| import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
|
|
@@ -60,6 +62,26 @@ trait Generator extends Expression { | |
| * rows can be made here. | ||
| */ | ||
| def terminate(): TraversableOnce[InternalRow] = Nil | ||
|
|
||
| /** | ||
| * Check if this generator supports code generation. | ||
| */ | ||
| def supportCodegen: Boolean = !isInstanceOf[CodegenFallback] | ||
| } | ||
|
|
||
| /** | ||
| * A collection producing [[Generator]]. This trait provides a different path for code generation, | ||
| * by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object. | ||
| */ | ||
| trait CollectionGenerator extends Generator { | ||
| /** The position of an element within the collection should also be returned. */ | ||
| def position: Boolean | ||
|
|
||
| /** Rows will be inlined during generation. */ | ||
| def inline: Boolean | ||
|
|
||
| /** The type of the returned collection object. */ | ||
| def collectionType: DataType = dataType | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -77,7 +99,9 @@ case class UserDefinedGenerator( | |
| private def initializeConverters(): Unit = { | ||
| inputRow = new InterpretedProjection(children) | ||
| convertToScala = { | ||
| val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) | ||
| val inputSchema = StructType(children.map { e => | ||
| StructField(e.simpleString, e.dataType, nullable = true) | ||
| }) | ||
| CatalystTypeConverters.createToScalaConverter(inputSchema) | ||
| }.asInstanceOf[InternalRow => Row] | ||
| } | ||
|
|
@@ -109,8 +133,7 @@ case class UserDefinedGenerator( | |
| 1 2 | ||
| 3 NULL | ||
| """) | ||
| case class Stack(children: Seq[Expression]) | ||
| extends Expression with Generator with CodegenFallback { | ||
| case class Stack(children: Seq[Expression]) extends Generator { | ||
|
|
||
| private lazy val numRows = children.head.eval().asInstanceOf[Int] | ||
| private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt | ||
|
|
@@ -149,29 +172,58 @@ case class Stack(children: Seq[Expression]) | |
| InternalRow(fields: _*) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Only support code generation when stack produces 50 rows or less. | ||
| */ | ||
| override def supportCodegen: Boolean = numRows <= 50 | ||
|
|
||
| override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| // Rows - we write these into an array. | ||
| val rowData = ctx.freshName("rows") | ||
| ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];") | ||
| val values = children.tail | ||
| val dataTypes = values.take(numFields).map(_.dataType) | ||
| val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unfortunately, splitExpressions does not work with whole stage codegen, because the input of children expression is not [inline code lost in extraction]. Can you add a test for that (many rows)?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have added a test to check if it will fail compilation for a large stack expression. Testing the actual fallback is nearly impossible. I would need to change something on this line: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala#L359 |
||
| val fields = Seq.tabulate(numFields) { col => | ||
| val index = row * numFields + col | ||
| if (index < values.length) values(index) else Literal(null, dataTypes(col)) | ||
| } | ||
| val eval = CreateStruct(fields).genCode(ctx) | ||
| s"${eval.code}\nthis.$rowData[$row] = ${eval.value};" | ||
| }) | ||
|
|
||
| // Create the collection. | ||
| val wrapperClass = classOf[mutable.WrappedArray[_]].getName | ||
| ctx.addMutableState( | ||
| s"$wrapperClass<InternalRow>", | ||
| ev.value, | ||
| s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);") | ||
| ev.copy(code = code, isNull = "false") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A base class for Explode and PosExplode | ||
| * A base class for [[Explode]] and [[PosExplode]]. | ||
| */ | ||
| abstract class ExplodeBase(child: Expression, position: Boolean) | ||
| extends UnaryExpression with Generator with CodegenFallback with Serializable { | ||
| abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable { | ||
| override val inline: Boolean = false | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { | ||
| override def checkInputDataTypes(): TypeCheckResult = child.dataType match { | ||
| case _: ArrayType | _: MapType => | ||
| TypeCheckResult.TypeCheckSuccess | ||
| } else { | ||
| case _ => | ||
| TypeCheckResult.TypeCheckFailure( | ||
| s"input to function explode should be array or map type, not ${child.dataType}") | ||
| } | ||
| } | ||
|
|
||
| // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) | ||
| override def elementSchema: StructType = child.dataType match { | ||
| case ArrayType(et, containsNull) => | ||
| if (position) { | ||
| new StructType() | ||
| .add("pos", IntegerType, false) | ||
| .add("pos", IntegerType, nullable = false) | ||
| .add("col", et, containsNull) | ||
| } else { | ||
| new StructType() | ||
|
|
@@ -180,12 +232,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean) | |
| case MapType(kt, vt, valueContainsNull) => | ||
| if (position) { | ||
| new StructType() | ||
| .add("pos", IntegerType, false) | ||
| .add("key", kt, false) | ||
| .add("pos", IntegerType, nullable = false) | ||
| .add("key", kt, nullable = false) | ||
| .add("value", vt, valueContainsNull) | ||
| } else { | ||
| new StructType() | ||
| .add("key", kt, false) | ||
| .add("key", kt, nullable = false) | ||
| .add("value", vt, valueContainsNull) | ||
| } | ||
| } | ||
|
|
@@ -218,6 +270,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean) | |
| } | ||
| } | ||
| } | ||
|
|
||
| override def collectionType: DataType = child.dataType | ||
|
|
||
| override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| child.genCode(ctx) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -239,7 +297,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean) | |
| 20 | ||
| """) | ||
| // scalastyle:on line.size.limit | ||
| case class Explode(child: Expression) extends ExplodeBase(child, position = false) | ||
| case class Explode(child: Expression) extends ExplodeBase { | ||
| override val position: Boolean = false | ||
| } | ||
|
|
||
| /** | ||
| * Given an input array produces a sequence of rows for each position and value in the array. | ||
|
|
@@ -260,7 +320,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals | |
| 1 20 | ||
| """) | ||
| // scalastyle:on line.size.limit | ||
| case class PosExplode(child: Expression) extends ExplodeBase(child, position = true) | ||
| case class PosExplode(child: Expression) extends ExplodeBase { | ||
| override val position = true | ||
| } | ||
|
|
||
| /** | ||
| * Explodes an array of structs into a table. | ||
|
|
@@ -273,20 +335,24 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t | |
| 1 a | ||
| 2 b | ||
| """) | ||
| case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback { | ||
| case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator { | ||
| override val inline: Boolean = true | ||
| override val position: Boolean = false | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = child.dataType match { | ||
| case ArrayType(et, _) if et.isInstanceOf[StructType] => | ||
| case ArrayType(st: StructType, _) => | ||
| TypeCheckResult.TypeCheckSuccess | ||
| case _ => | ||
| TypeCheckResult.TypeCheckFailure( | ||
| s"input to function $prettyName should be array of struct type, not ${child.dataType}") | ||
| } | ||
|
|
||
| override def elementSchema: StructType = child.dataType match { | ||
| case ArrayType(et : StructType, _) => et | ||
| case ArrayType(st: StructType, _) => st | ||
| } | ||
|
|
||
| override def collectionType: DataType = child.dataType | ||
|
|
||
| private lazy val numFields = elementSchema.fields.length | ||
|
|
||
| override def eval(input: InternalRow): TraversableOnce[InternalRow] = { | ||
|
|
@@ -298,4 +364,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with | |
| yield inputArray.getStruct(i, numFields) | ||
| } | ||
| } | ||
|
|
||
| override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| child.genCode(ctx) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`collectionType` is better? Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw, do we need this interface? Adding a new interface makes the code more complicated, I think.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I needed a way to make sure that the collection based code path (iteration over ArrayData/MapData) can be easily identified. The other option would be to hard-code all Generators that support this code path, but that seemed just wrong to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for collectionType