Skip to content

Commit 3d6d818

Browse files
ahshahidsumwale
authored and committed
[SNAPPYDATA] Bootstrap perf (#16)
This change involves: 1) Reducing the generated code size when writing a struct whose fields all have the same data type. 2) Fixing an issue in WholeStageCodegenExec where a plan supporting codegen was not prefixed by an InputAdapter when the node did not participate in whole-stage codegen.
1 parent 091f17b commit 3d6d818

File tree

3 files changed

+94
-22
lines changed

3 files changed

+94
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,52 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
5454

5555
val rowClass = classOf[GenericInternalRow].getName
5656

57-
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
57+
val isHomogenousStruct = {
58+
var i = 1
59+
val ref = CodeGenerator.javaType(schema.fields(0).dataType)
60+
var broken = !CodeGenerator.isPrimitiveType(ref) || schema.length <= 1
61+
while (!broken && i < schema.length) {
62+
if (CodeGenerator.javaType(schema.fields(i).dataType) != ref) {
63+
broken = true
64+
}
65+
i += 1
66+
}
67+
!broken
68+
}
69+
val allFields = if (isHomogenousStruct) {
70+
val counter = ctx.freshName("counter")
71+
val dt = schema.fields(0).dataType
5872
val converter = convertToSafe(
5973
ctx,
6074
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt),
6175
dt)
6276
s"""
63-
if (!$tmpInput.isNullAt($i)) {
64-
${converter.code}
65-
$values[$i] = ${converter.value};
77+
for (int $counter = 0; $counter < ${schema.length}; ++$counter) {
78+
if (!$tmpInput.isNullAt($counter)) {
79+
${converter.code}
80+
$values[$counter] = ${converter.value};
81+
}
6682
}
6783
"""
84+
} else {
85+
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
86+
val converter = convertToSafe(
87+
ctx,
88+
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt),
89+
dt)
90+
s"""
91+
if (!$tmpInput.isNullAt($i)) {
92+
${converter.code}
93+
$values[$i] = ${converter.value};
94+
}
95+
"""
96+
}
97+
ctx.splitExpressions(
98+
expressions = fieldWriters,
99+
funcName = "writeFields",
100+
arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values)
101+
)
68102
}
69-
val allFields = ctx.splitExpressions(
70-
expressions = fieldWriters,
71-
funcName = "writeFields",
72-
arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values)
73-
)
74103
val code =
75104
code"""
76105
|final InternalRow $tmpInput = $input;

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,58 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
6666
val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
6767
v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});")
6868
val previousCursor = ctx.freshName("previousCursor")
69-
s"""
70-
|final InternalRow $tmpInput = $input;
71-
|if ($tmpInput instanceof UnsafeRow) {
72-
| $rowWriter.write($index, (UnsafeRow) $tmpInput);
73-
|} else {
74-
| // Remember the current cursor so that we can calculate how many bytes are
75-
| // written later.
76-
| final int $previousCursor = $rowWriter.cursor();
77-
| ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)}
78-
| $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
79-
|}
80-
""".stripMargin
69+
70+
val isHomogenousStruct = {
71+
var i = 1
72+
val ref = CodeGenerator.javaType(t.fields(0).dataType)
73+
var broken = !CodeGenerator.isPrimitiveType(ref) || t.length <=1
74+
while (!broken && i < t.length) {
75+
if (CodeGenerator.javaType(t.fields(i).dataType) != ref) {
76+
broken = true
77+
}
78+
i +=1
79+
}
80+
!broken
81+
}
82+
if (isHomogenousStruct) {
83+
val counter = ctx.freshName("counter")
84+
val rowWriterChild = ctx.freshName("rowWriterChild")
85+
val dt = t.fields(0).dataType
86+
87+
s"""
88+
|final InternalRow $tmpInput = $input;
89+
|if ($tmpInput instanceof UnsafeRow) {
90+
| $rowWriter.write($index, (UnsafeRow) $tmpInput);
91+
|} else {
92+
| // Remember the current cursor so that we can calculate how many bytes are
93+
| // written later.
94+
| final int $previousCursor = $rowWriter.cursor();
95+
| $rowWriterClass $rowWriterChild = new $rowWriterClass($rowWriter, ${t.length});
96+
| $rowWriterChild.reset();
97+
| for (int $counter = 0; $counter < ${t.length}; $counter++) {
98+
| if ($tmpInput.isNullAt($index)) {
99+
| $rowWriterChild.setNullAt($index);
100+
| } else {
101+
| $rowWriterChild.write($counter, ${CodeGenerator.getValue(tmpInput, dt, counter)});
102+
| }
103+
| }
104+
| $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
105+
|}
106+
""".stripMargin
107+
} else {
108+
s"""
109+
|final InternalRow $tmpInput = $input;
110+
|if ($tmpInput instanceof UnsafeRow) {
111+
| $rowWriter.write($index, (UnsafeRow) $tmpInput);
112+
|} else {
113+
| // Remember the current cursor so that we can calculate how many bytes are
114+
| // written later.
115+
| final int $previousCursor = $rowWriter.cursor();
116+
| ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas, structRowWriter)}
117+
| $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
118+
|}
119+
""".stripMargin
120+
}
81121
}
82122

83123
private def writeExpressionsToBuffer(

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,8 +744,11 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
744744
// domain object can not be written into unsafe row.
745745
case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
746746
plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
747-
case plan: CodegenSupport if supportCodegen(plan) =>
747+
case plan: CodegenSupport => if (supportCodegen(plan)) {
748748
WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId())
749+
} else {
750+
plan.withNewChildren(plan.children.map(insertInputAdapter))
751+
}
749752
case other =>
750753
other.withNewChildren(other.children.map(insertWholeStageCodegen))
751754
}

0 commit comments

Comments
 (0)