diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f9acc208b715..8f932a80098b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -703,6 +703,33 @@ object ScalaReflection extends ScalaReflection {
    */
  def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)
 
+  def getClassFromTypeHandleArray(tpe: Type): Class[_] = cleanUpReflectionObjects {
+    tpe.dealias match {
+      case ty if ty <:< localTypeOf[Array[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = ty
+        arrayClassFromType(elementType)
+
+      case ty => getClassFromType(ty)
+    }
+  }
+
+  private def arrayClassFromType(tpe: Type): Class[_] =
+    ScalaReflection.cleanUpReflectionObjects {
+      tpe.dealias match {
+        case t if t <:< definitions.IntTpe => classOf[Array[Int]]
+        case t if t <:< definitions.LongTpe => classOf[Array[Long]]
+        case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
+        case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
+        case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
+        case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
+        case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+        case _ =>
+          // For non-primitive element types, fall back to building the array class reflectively.
+          val elementType = getClassFromTypeHandleArray(tpe)
+          java.lang.reflect.Array.newInstance(elementType, 1).getClass
+      }
+    }
+
  case class Schema(dataType: DataType, nullable: Boolean)
 
  /** Returns a Sequence of attributes for the given case class type. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 82b4eb9fba24..dc3c5dfe8f26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,9 +17,12 @@
 package org.apache.spark.sql.execution
 
+import java.lang.reflect.Constructor
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, AnalysisException, Strategy}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.ScalaReflection._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -474,8 +477,78 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
-  // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
+
+    import universe._
+
+    // Enumerates the pairs of logical plan and physical plan that can be transformed via
+    // 'pass-through', which is possible when the only difference between the primary
+    // constructor parameters of the two plans is LogicalPlan vs. SparkPlan.
+    // The map must exclude pairs where 'pass-through' would have to rely on a default value
+    // of a constructor parameter.
+    val passThroughOperators: Map[Class[_ <: LogicalPlan], Class[_ <: SparkPlan]] = Map(
+      (classOf[logical.DeserializeToObject], classOf[execution.DeserializeToObjectExec]),
+      (classOf[logical.SerializeFromObject], classOf[execution.SerializeFromObjectExec]),
+      (classOf[logical.MapPartitions], classOf[execution.MapPartitionsExec]),
+      (classOf[logical.FlatMapGroupsInR], classOf[execution.FlatMapGroupsInRExec]),
+      (classOf[logical.FlatMapGroupsInPandas], classOf[execution.python.FlatMapGroupsInPandasExec]),
+      (classOf[logical.AppendColumnsWithObject], classOf[execution.AppendColumnsWithObjectExec]),
+      (classOf[logical.MapGroups], classOf[execution.MapGroupsExec]),
+      (classOf[logical.CoGroup], classOf[execution.CoGroupExec]),
+      (classOf[logical.Project], classOf[execution.ProjectExec]),
+      (classOf[logical.Filter], classOf[execution.FilterExec]),
+      (classOf[logical.Window], classOf[execution.window.WindowExec]),
+      (classOf[logical.Sample], classOf[execution.SampleExec])
+    )
+
+    lazy val operatorToConstructorParameters: Map[Class[_ <: LogicalPlan], Seq[(String, Type)]] =
+      passThroughOperators.map { case (srcOpCls, _) =>
+        (srcOpCls, ScalaReflection.getConstructorParameters(srcOpCls))
+      }.toMap
+
+    lazy val operatorToTargetConstructor: Map[Class[_ <: LogicalPlan], Constructor[_]] =
+      passThroughOperators.map { case (srcOpCls, tgtOpCls) =>
+        val logicalPlanCls = classOf[LogicalPlan]
+        val m = runtimeMirror(logicalPlanCls.getClassLoader)
+        val classSymbol = m.staticClass(logicalPlanCls.getName)
+        val logicalPlanType = classSymbol.selfType
+
+        val paramTypes = operatorToConstructorParameters(srcOpCls).map(_._2)
+        val convertedParamTypes = ScalaReflection.cleanUpReflectionObjects {
+          paramTypes.map {
+            case ty if ty <:< logicalPlanType =>
+              m.staticClass(classOf[SparkPlan].getName).selfType
+
+            case ty => ty
+          }
+        }
+
+        val convertedParamClasses = convertedParamTypes.map(
+          ScalaReflection.getClassFromTypeHandleArray)
+        val constructorOption = ScalaReflection.findConstructor(tgtOpCls, convertedParamClasses)
+
+        constructorOption match {
+          case Some(const: Constructor[_]) => (srcOpCls, const)
+          case _ => throw new IllegalStateException(
+            s"A constructor matching ${srcOpCls.getName} must be present in ${tgtOpCls.getName}!")
+        }
+      }.toMap
+
+    def createPassThroughOutputPlan(src: LogicalPlan): SparkPlan = {
+      val srcClass = src.getClass
+      require(passThroughOperators.contains(srcClass))
+      val paramValues = operatorToConstructorParameters(srcClass).map(_._1).map { name =>
+        srcClass.getMethod(name).invoke(src)
+      }
+      val convertedParamValues = paramValues.map {
+        case p if p.isInstanceOf[LogicalPlan] => planLater(p.asInstanceOf[LogicalPlan])
+        case p => p
+      }
+
+      val const = operatorToTargetConstructor(srcClass)
+      const.newInstance(convertedParamValues: _*).asInstanceOf[SparkPlan]
+    }
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
       case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
@@ -497,36 +570,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw new IllegalStateException(
           "logical except operator should have been replaced by anti-join in the optimizer")
 
-      case logical.DeserializeToObject(deserializer, objAttr, child) =>
-        execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
-      case logical.SerializeFromObject(serializer, child) =>
-        execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
-      case logical.MapPartitions(f, objAttr, child) =>
-        execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
+      case src if passThroughOperators.contains(src.getClass) =>
+        createPassThroughOutputPlan(src) :: Nil
+
       case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
         execution.MapPartitionsExec(
           execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
-      case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
-        execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
-          data, objAttr, planLater(child)) :: Nil
-      case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
-        execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, _, _, in, out, child) =>
         execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
-      case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
-        execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
-      case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
-        execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
       case logical.FlatMapGroupsWithState(
           f, key, value, grouping, data, output, _, _, _, timeout, child) =>
         execution.MapGroupsExec(
           f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
-      case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
-        execution.CoGroupExec(
-          f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
-          planLater(left), planLater(right)) :: Nil
 
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
@@ -536,18 +593,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
       case logical.Sort(sortExprs, global, child) =>
         execution.SortExec(sortExprs, global, planLater(child)) :: Nil
-      case logical.Project(projectList, child) =>
-        execution.ProjectExec(projectList, planLater(child)) :: Nil
-      case logical.Filter(condition, child) =>
-        execution.FilterExec(condition, planLater(child)) :: Nil
       case f: logical.TypedFilter =>
         execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
       case e @ logical.Expand(_, _, child) =>
         execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
-        execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
-      case logical.Sample(lb, ub, withReplacement, seed, child) =>
-        execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data, _) =>
         LocalTableScanExec(output, data) :: Nil
       case logical.LocalLimit(IntegerLiteral(limit), child) =>
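
Note on the ScalaReflection change (illustration only, not part of the patch): the existing getClassFromType resolves a class from tpe.dealias.typeSymbol.asClass, which drops the type arguments, so for Array[T] it cannot recover the element-specific array class. The sketch below shows the same idea as getClassFromTypeHandleArray/arrayClassFromType in a standalone program; ArrayClassSketch and classFor are made-up names, and only a few primitive cases are spelled out.

```scala
import scala.reflect.runtime.universe._

// Illustrative only: not part of the patch.
object ArrayClassSketch {
  private val mirror = runtimeMirror(getClass.getClassLoader)

  def classFor(tpe: Type): Class[_] = tpe.dealias match {
    case t if t <:< typeOf[Array[_]] =>
      // Recover the element type argument, then pick the matching array class.
      val TypeRef(_, _, Seq(elem)) = t
      elem.dealias match {
        case e if e <:< definitions.IntTpe     => classOf[Array[Int]]     // [I
        case e if e <:< definitions.LongTpe    => classOf[Array[Long]]    // [J
        case e if e <:< definitions.BooleanTpe => classOf[Array[Boolean]] // [Z
        case e =>
          // Non-primitive elements: build the array class reflectively.
          java.lang.reflect.Array.newInstance(classFor(e), 0).getClass
      }
    case t =>
      // Non-array types: the type symbol alone is enough.
      mirror.runtimeClass(t.typeSymbol.asClass)
  }

  def main(args: Array[String]): Unit = {
    println(classFor(typeOf[Array[Int]]))    // class [I
    println(classFor(typeOf[Array[String]])) // class [Ljava.lang.String;
    println(classFor(typeOf[Seq[Int]]))      // the Seq trait's runtime class
  }
}
```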
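Similarly, a rough self-contained sketch of the 'pass-through' mechanism that createPassThroughOutputPlan relies on: read the logical operator's constructor arguments back through its case-class accessors, replace any logical child with its planned physical counterpart, and invoke the matching constructor of the physical class. LogicalNode, PhysicalNode, LogicalFilter, PhysicalFilter, and passThrough are hypothetical stand-ins rather than Spark classes, and the sketch assumes a single public constructor instead of going through findConstructor.

```scala
import scala.reflect.runtime.universe._

// Hypothetical stand-ins for LogicalPlan/SparkPlan; not Spark classes.
trait LogicalNode
trait PhysicalNode

case class LeafLogical() extends LogicalNode
case class LeafPhysical() extends PhysicalNode

case class LogicalFilter(condition: String, child: LogicalNode) extends LogicalNode
case class PhysicalFilter(condition: String, child: PhysicalNode) extends PhysicalNode

object PassThroughSketch {
  private val mirror = runtimeMirror(getClass.getClassLoader)

  // Primary-constructor parameter names of a case class, in declaration order.
  private def constructorParamNames(cls: Class[_]): Seq[String] = {
    val ctor = mirror.classSymbol(cls).toType.decl(termNames.CONSTRUCTOR).asMethod
    ctor.paramLists.flatten.map(_.name.decodedName.toString)
  }

  // 'Pass-through': copy every constructor argument of the logical node via its
  // accessor, replacing logical children with already-planned physical children,
  // then call the physical class's constructor with the converted arguments.
  def passThrough[T](src: AnyRef, targetCls: Class[T],
      planChild: LogicalNode => PhysicalNode): T = {
    val args = constructorParamNames(src.getClass).map { name =>
      src.getClass.getMethod(name).invoke(src) match {
        case l: LogicalNode => planChild(l)
        case other => other
      }
    }
    // Assumes exactly one public constructor (true for these simple case classes).
    targetCls.getConstructors.head.newInstance(args: _*).asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val logical = LogicalFilter("x > 1", LeafLogical())
    val physical = passThrough(logical, classOf[PhysicalFilter], _ => LeafPhysical())
    println(physical) // PhysicalFilter(x > 1,LeafPhysical())
  }
}
```

Unlike this sketch, the patch resolves the target constructor once per operator class in the lazily built operatorToTargetConstructor map, rather than searching for it on every planning call.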