Skip to content

Commit f86da0f

Browse files
committed
Use TraversableOnce for regular Generators.
1 parent 5cfba19 commit f86da0f

File tree

2 files changed

+88
-50
lines changed

2 files changed

+88
-50
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import com.fasterxml.jackson.core._
2626
import org.apache.spark.sql.catalyst.InternalRow
2727
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2828
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
29-
import org.apache.spark.sql.catalyst.util.GenericArrayData
3029
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
3130
import org.apache.spark.unsafe.types.UTF8String
3231
import org.apache.spark.util.Utils
@@ -328,9 +327,6 @@ case class GetJsonObject(json: Expression, path: Expression)
328327
// scalastyle:on line.size.limit
329328
case class JsonTuple(children: Seq[Expression]) extends Generator {
330329

331-
// a row is always returned
332-
override def nullable: Boolean = false
333-
334330
// if processing fails this shared value will be returned
335331
@transient private lazy val nullRow: Seq[InternalRow] =
336332
new GenericInternalRow(fieldExpressions.length) :: Nil
@@ -399,17 +395,16 @@ case class JsonTuple(children: Seq[Expression]) extends Generator {
399395
}
400396

401397
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
402-
val arrayDataClass = classOf[GenericArrayData].getName
398+
val iteratorClass = classOf[Iterator[_]].getName
403399
val rowClass = classOf[GenericInternalRow].getName
404-
def newArrayData(p: String): String = s"new $arrayDataClass(new Object[]{new $rowClass($p)})"
405400

406401
// Add an empty row to default to.
407402
val fieldCount = fieldExpressions.length
408403
val nullRow = ctx.freshName("nullRow")
409404
ctx.addMutableState(
410-
arrayDataClass,
405+
rowClass,
411406
nullRow,
412-
s"this.$nullRow = ${newArrayData(fieldCount.toString)};")
407+
s"this.$nullRow = new $rowClass(${fieldCount.toString});")
413408

414409
// Add the field names as a class field and add the foldable field names.
415410
val fieldNames = ctx.freshName("fieldNames")
@@ -432,19 +427,19 @@ case class JsonTuple(children: Seq[Expression]) extends Generator {
432427

433428
// Create the generated code.
434429
val jsonSource = jsonExpr.genCode(ctx)
435-
val result = ctx.freshName("result")
430+
val raw = ctx.freshName("raw")
431+
val row = ctx.freshName("row")
436432
val jsonTupleClass = classOf[JsonTuple].getName
437433
ev.copy(code = s"""
438434
|${jsonSource.code}
439435
|boolean ${ev.isNull} = false;
440-
|ArrayData ${ev.value} = null;
441-
|if (${jsonSource.isNull}) {
442-
| ${ev.value} = $nullRow;
443-
|} else {
436+
|InternalRow $row = $nullRow;
437+
|if (!(${jsonSource.isNull})) {
444438
| ${evalFieldNames.mkString("")}
445-
| Object[] $result = $jsonTupleClass.extractTuple(${jsonSource.value}, $fieldNames);
446-
| ${ev.value} = $result == null ? $nullRow : ${newArrayData(result)};
439+
| Object[] $raw = $jsonTupleClass.extractTuple(${jsonSource.value}, $fieldNames);
440+
| $row = $raw != null ? new $rowClass($raw) : $nullRow;
447441
|}
442+
|$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.single($row);
448443
""".stripMargin)
449444
}
450445
}

sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala

Lines changed: 78 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -113,46 +113,48 @@ case class GenerateExec(
113113
}
114114

115115
/**
 * Generate the code that processes one input row and emits the rows produced by the generator.
 *
 * @param ctx   code generation context; `currentVars` is pointed at `input` and `copyResult`
 *              is switched on before any expression code is generated.
 * @param input the already evaluated input columns.
 * @param row   the input row as a whole; forwarded to the specialised code paths.
 * @return the generated Java source for this operator.
 */
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
  ctx.currentVars = input
  ctx.copyResult = true

  // Add input rows to the values when we are joining
  val values = if (join) input else Seq.empty

  // Generate the driving expression from the same expression that is handed to the code path.
  boundGenerator match {
    // The Explode path iterates the array/map directly (it calls numElements() on the driving
    // value), so the driving expression must be the child rather than the generator itself.
    case e: Explode => codeGenExplode(ctx, e.child, values, e.child.genCode(ctx), row)
    // Every other generator is evaluated as-is and its result is consumed through an iterator.
    case g => codeGenTraversableOnce(ctx, g, values, g.genCode(ctx), row)
  }
}
121134

122-
/** Generate code for Generate. */
123-
private def codeGen(
135+
/**
136+
* Generate code for [[Explode]].
137+
*/
138+
private def codeGenExplode(
124139
ctx: CodegenContext,
125140
e: Expression,
126-
expand: Boolean,
127141
input: Seq[ExprCode],
142+
data: ExprCode,
128143
row: ExprCode): String = {
129-
ctx.currentVars = input
130-
ctx.copyResult = true
131-
132-
// Generate the driving expression.
133-
val data = e.genCode(ctx)
134144

135145
// Generate looping variables.
136-
val numOutput = metricTerm(ctx, "numOutputRows")
137146
val index = ctx.freshName("index")
138-
val numElements = ctx.freshName("numElements")
139147

140148
// Add a check if the generate outer flag is true.
141149
val checks = optionalCode(outer, data.isNull)
142-
val (initArrayData, initValues, values) = e.dataType match {
143-
case ArrayType(st: StructType, nullable) if expand =>
144-
val rowCode = codeGenAccessor(ctx, data.value, "col", index, st, nullable, checks)
145-
val extendedChecks = checks ++ optionalCode(nullable, rowCode.isNull)
146-
val values = st.fields.toSeq.zipWithIndex.map { case (f, i) =>
147-
codeGenAccessor(ctx, rowCode.value, f.name, s"$i", f.dataType, f.nullable, extendedChecks)
148-
}
149-
("", rowCode.code, values)
150150

151+
// Generate code for either ArrayData or MapData
152+
val (initMapData, values) = e.dataType match {
151153
case ArrayType(dataType, nullable) =>
152-
("", "", Seq(codeGenAccessor(ctx, data.value, "col", index, dataType, nullable, checks)))
154+
("", Seq(codeGenAccessor(ctx, data.value, "col", index, dataType, nullable, checks)))
153155

154156
case MapType(keyType, valueType, valueContainsNull) =>
155-
// Materialize the key and the value array before we enter the loop.
157+
// Materialize the key and the value arrays before we enter the loop.
156158
val keyArray = ctx.freshName("keyArray")
157159
val valueArray = ctx.freshName("valueArray")
158160
val initArrayData =
@@ -163,34 +165,75 @@ case class GenerateExec(
163165
val values = Seq(
164166
codeGenAccessor(ctx, keyArray, "key", index, keyType, nullable = false, checks),
165167
codeGenAccessor(ctx, valueArray, "value", index, valueType, valueContainsNull, checks))
166-
(initArrayData, "", values)
167-
}
168-
169-
// Determine result vars.
170-
val outputValues = if (join) {
171-
input ++ values
172-
} else {
173-
values
168+
(initArrayData, values)
174169
}
175170

176171
// In case of outer we need to make sure the loop is executed at-least once when the array/map
177172
// contains no input. We do this by setting the looping index to -1 if there is no input,
178173
// evaluation of the array is prevented by a check in the accessor code.
174+
val numElements = ctx.freshName("numElements")
179175
val init = if (outer) s"$numElements == 0 ? -1 : 0" else "0"
176+
val numOutput = metricTerm(ctx, "numOutputRows")
180177
s"""
181178
|${data.code}
182-
|$initArrayData
179+
|$initMapData
183180
|int $numElements = ${data.isNull} ? 0 : ${data.value}.numElements();
184181
|for (int $index = $init; $index < $numElements; $index++) {
185182
| $numOutput.add(1);
186-
| $initValues
187-
| ${consume(ctx, outputValues)}
183+
| ${consume(ctx, input ++ values)}
188184
|}
189185
""".stripMargin
190186
}
191187

192188
/**
193-
* Generate for accessor code for ArrayData and InternalRows.
189+
* Generate code for a regular [[TraversableOnce]] returning [[Generator]].
190+
*/
191+
private def codeGenTraversableOnce(
192+
ctx: CodegenContext,
193+
e: Expression,
194+
input: Seq[ExprCode],
195+
data: ExprCode,
196+
row: ExprCode): String = {
197+
198+
// Generate looping variables.
199+
val iterator = ctx.freshName("iterator")
200+
val hasNext = ctx.freshName("hasNext")
201+
val current = ctx.freshName("row")
202+
203+
// Add a check if the generate outer flag is true.
204+
val checks = optionalCode(outer, hasNext)
205+
val values = e.dataType match {
206+
case ArrayType(st: StructType, nullable) =>
207+
st.fields.toSeq.zipWithIndex.map { case (f, i) =>
208+
codeGenAccessor(ctx, current, f.name, s"$i", f.dataType, f.nullable, checks)
209+
}
210+
}
211+
212+
// In case of outer we need to make sure the loop is executed at-least-once when the iterator
213+
// contains no input. We do this by adding an 'outer' variable which guarantees execution of
214+
// the first iteration even if there is no input. Evaluation of the iterator is prevented by a
215+
// check in the accessor code.
216+
val hasNextCode = s"$hasNext = $iterator.hasNext()"
217+
val outerVal = ctx.freshName("outer")
218+
def concatIfOuter(s1: String, s2: String): String = s1 + (if (outer) s2 else "")
219+
val init = concatIfOuter(s"boolean $hasNextCode", s", $outerVal = true")
220+
val check = concatIfOuter(hasNext, s"|| $outerVal")
221+
val update = concatIfOuter(hasNextCode, s", $outerVal = false")
222+
val next = if (outer) s"$hasNext ? $iterator.next() : null" else s"$iterator.next()"
223+
val numOutput = metricTerm(ctx, "numOutputRows")
224+
s"""
225+
|${data.code}
226+
|scala.collection.Iterator<InternalRow> $iterator = ${data.value}.toIterator();
227+
|for ($init; $check; $update) {
228+
| $numOutput.add(1);
229+
| InternalRow $current = (InternalRow) $next;
230+
| ${consume(ctx, input ++ values)}
231+
|}
232+
""".stripMargin
233+
}
234+
235+
/**
236+
* Generate accessor code for ArrayData and InternalRows.
194237
*/
195238
private def codeGenAccessor(
196239
ctx: CodegenContext,
@@ -210,7 +253,7 @@ case class GenerateExec(
210253
s"""
211254
|boolean $isNull = ${checks.mkString(" || ")};
212255
|$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter;
213-
""".stripMargin
256+
""".stripMargin
214257
ExprCode(code, isNull, value)
215258
} else {
216259
ExprCode(s"$javaType $value = $getter;", "false", value)

0 commit comments

Comments
 (0)