1717
1818package org .apache .spark .sql .execution
1919
20+ import java .lang .reflect .Constructor
21+
2022import org .apache .spark .rdd .RDD
2123import org .apache .spark .sql .{execution , AnalysisException , Strategy }
22- import org .apache .spark .sql .catalyst .InternalRow
24+ import org .apache .spark .sql .catalyst .{InternalRow , ScalaReflection }
25+ import org .apache .spark .sql .catalyst .ScalaReflection ._
2326import org .apache .spark .sql .catalyst .encoders .RowEncoder
2427import org .apache .spark .sql .catalyst .expressions ._
2528import org .apache .spark .sql .catalyst .expressions .aggregate .AggregateExpression
@@ -474,8 +477,80 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
474477 }
475478 }
476479
477- // Can we automate these 'pass through' operations?
478480 object BasicOperators extends Strategy {
481+
482+ import universe ._
483+
// Pairs of (logical operator class -> physical operator class) that can be planned by
// 'pass-through': the primary constructors of both operators take the same parameters, the
// only difference being that a LogicalPlan parameter on the logical side is a SparkPlan
// parameter on the physical side.
// Do NOT list a pair here if converting it would have to rely on the default value of a
// constructor parameter.
val passThroughOperators: Map[Class[_ <: LogicalPlan], Class[_ <: SparkPlan]] = Map(
  classOf[logical.DeserializeToObject] -> classOf[execution.DeserializeToObjectExec],
  classOf[logical.SerializeFromObject] -> classOf[execution.SerializeFromObjectExec],
  classOf[logical.MapPartitions] -> classOf[execution.MapPartitionsExec],
  classOf[logical.FlatMapGroupsInR] -> classOf[execution.FlatMapGroupsInRExec],
  classOf[logical.FlatMapGroupsInPandas] -> classOf[execution.python.FlatMapGroupsInPandasExec],
  classOf[logical.AppendColumnsWithObject] -> classOf[execution.AppendColumnsWithObjectExec],
  classOf[logical.MapGroups] -> classOf[execution.MapGroupsExec],
  classOf[logical.CoGroup] -> classOf[execution.CoGroupExec],
  classOf[logical.Project] -> classOf[execution.ProjectExec],
  classOf[logical.Filter] -> classOf[execution.FilterExec],
  classOf[logical.Window] -> classOf[execution.window.WindowExec],
  classOf[logical.Sample] -> classOf[execution.SampleExec]
)
503+
// For every pass-through logical operator class, the (name, type) pairs returned by
// ScalaReflection.getConstructorParameters for that class. Computed lazily on first access.
lazy val operatorToConstructorParameters: Map[Class[_ <: LogicalPlan], Seq[(String, Type)]] =
  passThroughOperators.keys.map { logicalCls =>
    (logicalCls, ScalaReflection.getConstructorParameters(logicalCls))
  }.toMap
509+
// For every pass-through logical operator class, the constructor of the paired physical
// operator class whose parameter list matches the logical operator's constructor parameters
// with every LogicalPlan-typed parameter replaced by SparkPlan. Computed lazily on first
// access.
//
// Throws IllegalStateException if the physical operator class has no such constructor.
lazy val operatorToTargetConstructor: Map[Class[_ <: LogicalPlan], Constructor[_]] = {
  // These reflection handles do not depend on the operator being processed, so resolve them
  // once instead of once per operator (and once per parameter for the SparkPlan type).
  val mirror = runtimeMirror(classOf[LogicalPlan].getClassLoader)
  val logicalPlanType = mirror.staticClass(classOf[LogicalPlan].getName).selfType
  val sparkPlanType = mirror.staticClass(classOf[SparkPlan].getName).selfType

  passThroughOperators.map {
    case (srcOpCls, tgtOpCls) =>
      val paramTypes = operatorToConstructorParameters(srcOpCls).map(_._2)

      // Swap LogicalPlan (or a subtype) for SparkPlan; every other parameter type is kept.
      val convertedParamTypes = ScalaReflection.cleanUpReflectionObjects {
        paramTypes.map {
          case ty if ty <:< logicalPlanType => sparkPlanType
          case ty => ty
        }
      }

      val convertedParamClasses = convertedParamTypes.map(
        ScalaReflection.getClassFromTypeHandleArray)

      ScalaReflection.findConstructor(tgtOpCls, convertedParamClasses) match {
        case Some(const: Constructor[_]) => (srcOpCls, const)
        case _ => throw new IllegalStateException(
          s"Matching constructor ${srcOpCls.getName} must be presented in ${tgtOpCls.getName}!")
      }
  }.toMap
}
538+
/**
 * Builds the physical plan for a 'pass-through' logical operator via reflection.
 *
 * Each constructor parameter of `src` is read through the accessor method of the same name;
 * values that are LogicalPlan instances are replaced by `planLater(...)` placeholders and all
 * other values are passed unchanged to the constructor looked up in
 * `operatorToTargetConstructor`.
 *
 * @param src logical operator whose class must be a key of `passThroughOperators`
 * @return the physical operator instance created from the converted parameter values
 * @throws IllegalArgumentException if the class of `src` is not registered in
 *                                  `passThroughOperators`
 */
def createPassThroughOutputPlan(src: LogicalPlan): SparkPlan = {
  val srcClass = src.getClass
  require(passThroughOperators.contains(srcClass),
    s"${srcClass.getName} is not registered as a pass-through operator")

  try {
    val paramValues = operatorToConstructorParameters(srcClass).map { case (name, _) =>
      srcClass.getMethod(name).invoke(src)
    }
    val convertedParamValues = paramValues.map {
      case child: LogicalPlan => planLater(child)
      case other => other
    }

    val const = operatorToTargetConstructor(srcClass)
    const.newInstance(convertedParamValues: _*).asInstanceOf[SparkPlan]
  } catch {
    // Reflection wraps whatever the accessor or constructor threw. Rethrow the underlying
    // exception so callers see the same error as with a direct (non-reflective) call.
    case e: java.lang.reflect.InvocationTargetException if e.getCause != null =>
      throw e.getCause
  }
}
553+
479554 def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
480555 case d : DataWritingCommand => DataWritingCommandExec (d, planLater(d.query)) :: Nil
481556 case r : RunnableCommand => ExecutedCommandExec (r) :: Nil
@@ -497,36 +572,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
497572 throw new IllegalStateException (
498573 " logical except operator should have been replaced by anti-join in the optimizer" )
499574
500- case logical.DeserializeToObject (deserializer, objAttr, child) =>
501- execution.DeserializeToObjectExec (deserializer, objAttr, planLater(child)) :: Nil
502- case logical.SerializeFromObject (serializer, child) =>
503- execution.SerializeFromObjectExec (serializer, planLater(child)) :: Nil
504- case logical.MapPartitions (f, objAttr, child) =>
505- execution.MapPartitionsExec (f, objAttr, planLater(child)) :: Nil
575+ case src if passThroughOperators.contains(src.getClass) =>
576+ createPassThroughOutputPlan(src) :: Nil
577+
506578 case logical.MapPartitionsInR (f, p, b, is, os, objAttr, child) =>
507579 execution.MapPartitionsExec (
508580 execution.r.MapPartitionsRWrapper (f, p, b, is, os), objAttr, planLater(child)) :: Nil
509581 case logical.FlatMapGroupsInR (f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
510582 execution.FlatMapGroupsInRExec (f, p, b, is, os, key, value, grouping,
511583 data, objAttr, planLater(child)) :: Nil
512- case logical.FlatMapGroupsInPandas (grouping, func, output, child) =>
513- execution.python.FlatMapGroupsInPandasExec (grouping, func, output, planLater(child)) :: Nil
514584 case logical.MapElements (f, _, _, objAttr, child) =>
515585 execution.MapElementsExec (f, objAttr, planLater(child)) :: Nil
516586 case logical.AppendColumns (f, _, _, in, out, child) =>
517587 execution.AppendColumnsExec (f, in, out, planLater(child)) :: Nil
518- case logical.AppendColumnsWithObject (f, childSer, newSer, child) =>
519- execution.AppendColumnsWithObjectExec (f, childSer, newSer, planLater(child)) :: Nil
520- case logical.MapGroups (f, key, value, grouping, data, objAttr, child) =>
521- execution.MapGroupsExec (f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
522588 case logical.FlatMapGroupsWithState (
523589 f, key, value, grouping, data, output, _, _, _, timeout, child) =>
524590 execution.MapGroupsExec (
525591 f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
526- case logical.CoGroup (f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
527- execution.CoGroupExec (
528- f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
529- planLater(left), planLater(right)) :: Nil
530592
531593 case logical.Repartition (numPartitions, shuffle, child) =>
532594 if (shuffle) {
@@ -536,18 +598,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
536598 }
537599 case logical.Sort (sortExprs, global, child) =>
538600 execution.SortExec (sortExprs, global, planLater(child)) :: Nil
539- case logical.Project (projectList, child) =>
540- execution.ProjectExec (projectList, planLater(child)) :: Nil
541- case logical.Filter (condition, child) =>
542- execution.FilterExec (condition, planLater(child)) :: Nil
543601 case f : logical.TypedFilter =>
544602 execution.FilterExec (f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
545603 case e @ logical.Expand (_, _, child) =>
546604 execution.ExpandExec (e.projections, e.output, planLater(child)) :: Nil
547- case logical.Window (windowExprs, partitionSpec, orderSpec, child) =>
548- execution.window.WindowExec (windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
549- case logical.Sample (lb, ub, withReplacement, seed, child) =>
550- execution.SampleExec (lb, ub, withReplacement, seed, planLater(child)) :: Nil
551605 case logical.LocalRelation (output, data, _) =>
552606 LocalTableScanExec (output, data) :: Nil
553607 case logical.LocalLimit (IntegerLiteral (limit), child) =>
0 commit comments