
Commit 139aefa

[SPARK-24336][SQL] Support 'pass through' transformation in BasicOperators
1 parent 5be8aab

2 files changed: +107 -26 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 27 additions & 0 deletions
@@ -703,6 +703,33 @@ object ScalaReflection extends ScalaReflection {
    */
   def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)
 
+  def getClassFromTypeHandleArray(tpe: Type): Class[_] = cleanUpReflectionObjects {
+    tpe.dealias match {
+      case ty if ty <:< localTypeOf[Array[_]] =>
+        def arrayClassFromType(tpe: Type): Class[_] =
+          ScalaReflection.cleanUpReflectionObjects {
+            tpe.dealias match {
+              case t if t <:< definitions.IntTpe => classOf[Array[Int]]
+              case t if t <:< definitions.LongTpe => classOf[Array[Long]]
+              case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
+              case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
+              case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
+              case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
+              case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+              case _ =>
+                // There is probably a better way to do this, but I couldn't find it...
+                val elementType = getClassFromTypeHandleArray(tpe)
+                java.lang.reflect.Array.newInstance(elementType, 1).getClass
+            }
+          }
+
+        val TypeRef(_, _, Seq(elementType)) = ty
+        arrayClassFromType(elementType)
+
+      case ty => getClassFromType(ty)
+    }
+  }
+
   case class Schema(dataType: DataType, nullable: Boolean)
 
   /** Returns a Sequence of attributes for the given case class type. */
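
Why the helper above is needed: resolving an Array type through the existing getClassFromType loses the element type (only the erased Array symbol survives), so a constructor lookup against a parameter such as Array[Int] (int[] on the JVM) would not match. A minimal sketch of the trick the fallback branch relies on, in plain Scala with no Spark classes (the assertions are illustrative, not part of the commit):

    // Building a one-element array and asking for its class recovers the
    // precise JVM array class, including for primitive element types:
    val intArrayClass =
      java.lang.reflect.Array.newInstance(classOf[Int], 1).getClass
    assert(intArrayClass == classOf[Array[Int]])  // int[], not Object[]

    // Nested arrays resolve the same way, which is why the fallback branch
    // can recurse through getClassFromTypeHandleArray first:
    val nestedArrayClass =
      java.lang.reflect.Array.newInstance(classOf[Array[Int]], 1).getClass
    assert(nestedArrayClass == classOf[Array[Array[Int]]])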

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 80 additions & 26 deletions
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.execution
 
+import java.lang.reflect.Constructor
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, AnalysisException, Strategy}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.ScalaReflection._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -474,8 +477,80 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
-  // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
+
+    import universe._
+
+    // Enumerate the pairs of logical plan and physical plan that can be transformed via
+    // 'pass-through'. This is possible when the only difference between the parameters of
+    // the primary constructors of the two plans is LogicalPlan vs. SparkPlan.
+    // The map must exclude any pair whose 'pass-through' would need to rely on a default
+    // value of a constructor parameter.
+    val passThroughOperators: Map[Class[_ <: LogicalPlan], Class[_ <: SparkPlan]] = Map(
+      (classOf[logical.DeserializeToObject], classOf[execution.DeserializeToObjectExec]),
+      (classOf[logical.SerializeFromObject], classOf[execution.SerializeFromObjectExec]),
+      (classOf[logical.MapPartitions], classOf[execution.MapPartitionsExec]),
+      (classOf[logical.FlatMapGroupsInR], classOf[execution.FlatMapGroupsInRExec]),
+      (classOf[logical.FlatMapGroupsInPandas], classOf[execution.python.FlatMapGroupsInPandasExec]),
+      (classOf[logical.AppendColumnsWithObject], classOf[execution.AppendColumnsWithObjectExec]),
+      (classOf[logical.MapGroups], classOf[execution.MapGroupsExec]),
+      (classOf[logical.CoGroup], classOf[execution.CoGroupExec]),
+      (classOf[logical.Project], classOf[execution.ProjectExec]),
+      (classOf[logical.Filter], classOf[execution.FilterExec]),
+      (classOf[logical.Window], classOf[execution.window.WindowExec]),
+      (classOf[logical.Sample], classOf[execution.SampleExec])
+    )
+
+    lazy val operatorToConstructorParameters: Map[Class[_ <: LogicalPlan], Seq[(String, Type)]] =
+      passThroughOperators.map {
+        case (srcOpCls, _) =>
+          (srcOpCls, ScalaReflection.getConstructorParameters(srcOpCls))
+      }.toMap
+
+    lazy val operatorToTargetConstructor: Map[Class[_ <: LogicalPlan], Constructor[_]] =
+      passThroughOperators.map {
+        case (srcOpCls, tgtOpCls) =>
+          val logicalPlanCls = classOf[LogicalPlan]
+          val m = runtimeMirror(logicalPlanCls.getClassLoader)
+          val classSymbol = m.staticClass(logicalPlanCls.getName)
+          val logicalPlanType = classSymbol.selfType
+
+          val paramTypes = operatorToConstructorParameters(srcOpCls).map(_._2)
+          val convertedParamTypes = ScalaReflection.cleanUpReflectionObjects {
+            paramTypes.map {
+              case ty if ty <:< logicalPlanType =>
+                m.staticClass(classOf[SparkPlan].getName).selfType
+
+              case ty => ty
+            }
+          }
+
+          val convertedParamClasses = convertedParamTypes.map(
+            ScalaReflection.getClassFromTypeHandleArray)
+          val constructorOption = ScalaReflection.findConstructor(tgtOpCls, convertedParamClasses)
+
+          constructorOption match {
+            case Some(const: Constructor[_]) => (srcOpCls, const)
+            case _ => throw new IllegalStateException(
+              s"A constructor matching ${srcOpCls.getName} must be present in ${tgtOpCls.getName}!")
+          }
+      }.toMap
+
+    def createPassThroughOutputPlan(src: LogicalPlan): SparkPlan = {
+      val srcClass = src.getClass
+      require(passThroughOperators.contains(srcClass))
+      val paramValues = operatorToConstructorParameters(srcClass).map(_._1).map { name =>
+        srcClass.getMethod(name).invoke(src)
+      }
+      val convertedParamValues = paramValues.map {
+        case p: LogicalPlan => planLater(p)
+        case p => p
+      }
+
+      val const = operatorToTargetConstructor(srcClass)
+      const.newInstance(convertedParamValues: _*).asInstanceOf[SparkPlan]
+    }
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
       case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
@@ -497,36 +572,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw new IllegalStateException(
           "logical except operator should have been replaced by anti-join in the optimizer")
 
-      case logical.DeserializeToObject(deserializer, objAttr, child) =>
-        execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
-      case logical.SerializeFromObject(serializer, child) =>
-        execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
-      case logical.MapPartitions(f, objAttr, child) =>
-        execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
+      case src if passThroughOperators.contains(src.getClass) =>
+        createPassThroughOutputPlan(src) :: Nil
+
       case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
         execution.MapPartitionsExec(
           execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
       case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
         execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
           data, objAttr, planLater(child)) :: Nil
-      case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
-        execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, _, _, in, out, child) =>
         execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
-      case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
-        execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
-      case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
-        execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
       case logical.FlatMapGroupsWithState(
           f, key, value, grouping, data, output, _, _, _, timeout, child) =>
         execution.MapGroupsExec(
           f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
-      case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
-        execution.CoGroupExec(
-          f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
-          planLater(left), planLater(right)) :: Nil
 
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
@@ -536,18 +598,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
       case logical.Sort(sortExprs, global, child) =>
         execution.SortExec(sortExprs, global, planLater(child)) :: Nil
-      case logical.Project(projectList, child) =>
-        execution.ProjectExec(projectList, planLater(child)) :: Nil
-      case logical.Filter(condition, child) =>
-        execution.FilterExec(condition, planLater(child)) :: Nil
       case f: logical.TypedFilter =>
         execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
       case e @ logical.Expand(_, _, child) =>
         execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
-        execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
-      case logical.Sample(lb, ub, withReplacement, seed, child) =>
-        execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data, _) =>
         LocalTableScanExec(output, data) :: Nil
       case logical.LocalLimit(IntegerLiteral(limit), child) =>
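
Taken together, the pass-through machinery replaces a block of hand-written strategy cases with a single reflective one. A self-contained sketch of the same technique, reduced to plain Java reflection; all names here (Logical, Physical, FilterL, FilterP, toPhysical) are hypothetical stand-ins for LogicalPlan, SparkPlan, and createPassThroughOutputPlan:

    import java.lang.reflect.Constructor

    trait Logical
    case class FilterL(condition: String, child: Logical) extends Logical
    case object LeafL extends Logical

    trait Physical
    case class FilterP(condition: String, child: Physical) extends Physical
    case object LeafP extends Physical

    // The explicit pairing, like passThroughOperators in the commit.
    val passThrough: Map[Class[_], Class[_]] = Map(classOf[FilterL] -> classOf[FilterP])

    def toPhysical(l: Logical): Physical = l match {
      case LeafL => LeafP
      case _ =>
        // Swap every Logical-typed constructor parameter for its Physical
        // counterpart...
        val srcCtor = l.getClass.getConstructors.head
        val tgtParamClasses: Array[Class[_]] = srcCtor.getParameterTypes.map {
          case c if classOf[Logical].isAssignableFrom(c) => classOf[Physical]
          case c => c
        }
        // ...then look up the matching constructor on the target class...
        val tgtCtor: Constructor[_] =
          passThrough(l.getClass).getConstructor(tgtParamClasses: _*)
        // ...and rebuild the node, recursively converting Logical children.
        val args = l.asInstanceOf[Product].productIterator.map {
          case child: Logical => toPhysical(child)
          case other => other
        }.map(_.asInstanceOf[AnyRef]).toArray
        tgtCtor.newInstance(args: _*).asInstanceOf[Physical]
    }

    // toPhysical(FilterL("a > 1", LeafL)) == FilterP("a > 1", LeafP)

The commit's version differs in two ways worth noting: it uses ScalaReflection.getConstructorParameters rather than getParameterTypes, because it also needs parameter names to read field values back through the case-class accessors, and it converts parameter types via getClassFromTypeHandleArray so that Array-typed parameters resolve to the correct JVM array classes before findConstructor runs.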
