Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,43 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
ctx.addMutableState("Object[]", values, s"this.$values = null;")

val rowClass = classOf[GenericInternalRow].getName

val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
// True when the struct has more than one field, the type name that ctx.javaType reports for
// the first field is primitive (per ctx.isPrimitiveType), and every remaining field maps to
// that same type name. Only then can a single generated loop copy all fields.
// The length test is evaluated first so that an empty schema never touches schema.fields(0).
val isHomogenousStruct = schema.length > 1 && {
  val ref = ctx.javaType(schema.fields(0).dataType)
  ctx.isPrimitiveType(ref) &&
    schema.fields.drop(1).forall(field => ctx.javaType(field.dataType) == ref)
}
// Java source that copies every non-null field of `tmp` into the `values` array.
val allFields =
  if (!isHomogenousStruct) {
    // General case: one null-guarded assignment per field, each generated with that field's
    // own data type; the snippets are then handed to ctx.splitExpressions.
    val perFieldCode = schema.map(_.dataType).zipWithIndex.map { case (fieldType, ordinal) =>
      val safe = convertToSafe(ctx, ctx.getValue(tmp, fieldType, ordinal.toString), fieldType)
      s"""
if (!$tmp.isNullAt($ordinal)) {
${safe.code}
$values[$ordinal] = ${safe.value};
}
"""
    }
    ctx.splitExpressions(tmp, perFieldCode)
  } else {
    // Homogeneous case: the conversion code is generated once, using the first field's data
    // type and a generated loop variable as the ordinal, and wrapped in a single Java loop.
    val loopVar = ctx.freshName("counter")
    val elemType = schema.fields(0).dataType
    val safe = convertToSafe(ctx, ctx.getValue(tmp, elemType, loopVar), elemType)
    s"""
for(int $loopVar = 0; $loopVar < ${schema.length}; ++$loopVar) {
if (!$tmp.isNullAt($loopVar)) {
${safe.code}
$values[$loopVar] = ${safe.value};
}
}
"""
  }
val allFields = ctx.splitExpressions(tmp, fieldWriters)

val code = s"""
final InternalRow $tmp = $input;
this.$values = new Object[${schema.length}];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,56 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

val writeField = dt match {
case t: StructType =>
s"""
// True when the struct type `t` has more than one field, the type name that ctx.javaType
// reports for the first field is primitive (per ctx.isPrimitiveType), and every remaining
// field maps to that same type name.
// The length test is evaluated first so that an empty struct never touches t.fields(0).
val isHomogenousStruct = t.length > 1 && {
  val ref = ctx.javaType(t.fields(0).dataType)
  ctx.isPrimitiveType(ref) &&
    t.fields.drop(1).forall(field => ctx.javaType(field.dataType) == ref)
}
// Generated Java for writing a nested struct value into the buffer. Both branches record the
// cursor before writing and finish by registering (offset, size) for slot $index of the
// parent row writer.
if (isHomogenousStruct) {
  // Fresh names for the Java locals declared inside the generated snippet.
  val counter = ctx.freshName("counter")
  val rowWriterChild = ctx.freshName("rowWriterChild")

  // Homogeneous struct: a single generated loop writes every field through a child row
  // writer, reading each value with the first field's data type.
  // The null check and setNullAt must address the child field ($counter). $index is the
  // struct's own position in the parent row and is only used for setOffsetAndSize.
  s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.

final int $tmpCursor = $bufferHolder.cursor;

if (${input.value} instanceof UnsafeRow) {
${writeUnsafeData(ctx, s"((UnsafeRow) ${input.value})", bufferHolder)};
} else {
$rowWriterClass $rowWriterChild = new $rowWriterClass($bufferHolder, ${t.length});
$rowWriterChild.reset();
for(int $counter = 0; $counter < ${t.length}; ++$counter) {
if (${input.value}.isNullAt($counter)) {
$rowWriterChild.setNullAt($counter);
}else {
$rowWriterChild.write($counter, ${ctx.getValue(input.value, t.fields(0).dataType,
counter)});
}
}
}
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""
} else {
  // General struct: delegate the per-field writes to writeStructToBuffer.
  s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.

final int $tmpCursor = $bufferHolder.cursor;
${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)}
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""
}

case a @ ArrayType(et, _) =>
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
/**
 * Rewrites `plan` by pattern matching on it: depending on the case, the node is rebuilt with
 * transformed children via `withNewChildren`, or wrapped in `WholeStageCodegenExec`.
 */
private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
// For operators that output a domain object, do not insert WholeStageCodegen for them, as a
// domain object cannot be written into an unsafe row.
// NOTE(review): the case header below appears twice, differing only by a line break. This
// looks like both sides of a diff hunk; confirm which single header is intended.
case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
case plan if plan.output.length == 1 &&
plan.output.head.dataType.isInstanceOf[ObjectType] =>
// Leave this node unwrapped and recurse into its children.
plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
// NOTE(review): two alternative headers for the CodegenSupport case follow (a guarded form
// and an if/else form); confirm which one is intended.
case plan: CodegenSupport if supportCodegen(plan) =>
case plan: CodegenSupport => if (supportCodegen(plan)) {
WholeStageCodegenExec(insertInputAdapter(plan))
} else {
// NOTE(review): this branch maps children with insertInputAdapter, whereas the `case other`
// branch below maps them with insertWholeStageCodegen; confirm the difference is intended.
plan.withNewChildren(plan.children.map(insertInputAdapter))
}
// Any other operator: keep the node and recurse into its children.
case other =>
other.withNewChildren(other.children.map(insertWholeStageCodegen))
}
Expand Down